feat: merge with main
|
|
@ -1,80 +0,0 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
# Bug Report
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
|
||||
|
||||
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
|
||||
|
||||
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
|
||||
|
||||
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
|
||||
|
||||
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
|
||||
|
||||
---
|
||||
|
||||
## Installation Method
|
||||
|
||||
[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
|
||||
|
||||
## Environment
|
||||
|
||||
- **Open WebUI Version:** [e.g., v0.3.11]
|
||||
- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
|
||||
|
||||
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
|
||||
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
|
||||
|
||||
**Confirmation:**
|
||||
|
||||
- [ ] I have read and followed all the instructions provided in the README.md.
|
||||
- [ ] I am on the latest version of both Open WebUI and Ollama.
|
||||
- [ ] I have included the browser console logs.
|
||||
- [ ] I have included the Docker container logs.
|
||||
- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
|
||||
|
||||
## Expected Behavior:
|
||||
|
||||
[Describe what you expected to happen.]
|
||||
|
||||
## Actual Behavior:
|
||||
|
||||
[Describe what actually happened.]
|
||||
|
||||
## Description
|
||||
|
||||
**Bug Summary:**
|
||||
[Provide a brief but clear summary of the bug]
|
||||
|
||||
## Reproduction Details
|
||||
|
||||
**Steps to Reproduce:**
|
||||
[Outline the steps to reproduce the bug. Be as detailed as possible.]
|
||||
|
||||
## Logs and Screenshots
|
||||
|
||||
**Browser Console Logs:**
|
||||
[Include relevant browser console logs, if applicable]
|
||||
|
||||
**Docker Container Logs:**
|
||||
[Include relevant Docker container logs, if applicable]
|
||||
|
||||
**Screenshots/Screen Recordings (if applicable):**
|
||||
[Attach any relevant screenshots to help illustrate the issue]
|
||||
|
||||
## Additional Information
|
||||
|
||||
[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
|
||||
|
||||
## Note
|
||||
|
||||
If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
name: Bug Report
|
||||
description: Create a detailed bug report to help us improve Open WebUI.
|
||||
title: 'issue: '
|
||||
labels: ['bug', 'triage']
|
||||
assignees: []
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
# Bug Report
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
|
||||
|
||||
- **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
|
||||
|
||||
- **Contributing**: If you encounter an issue, consider submitting a pull request or forking the project. We prioritize preventing contributor burnout to maintain Open WebUI's quality.
|
||||
|
||||
- **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
|
||||
|
||||
- type: checkboxes
|
||||
id: issue-check
|
||||
attributes:
|
||||
label: Check Existing Issues
|
||||
description: Confirm that you’ve checked for existing reports before submitting a new one.
|
||||
options:
|
||||
- label: I have searched the existing issues and discussions.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: installation-method
|
||||
attributes:
|
||||
label: Installation Method
|
||||
description: How did you install Open WebUI?
|
||||
options:
|
||||
- Git Clone
|
||||
- Pip Install
|
||||
- Docker
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: open-webui-version
|
||||
attributes:
|
||||
label: Open WebUI Version
|
||||
description: Specify the version (e.g., v0.3.11)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: ollama-version
|
||||
attributes:
|
||||
label: Ollama Version (if applicable)
|
||||
description: Specify the version (e.g., v0.2.0, or v0.1.32-rc1)
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: operating-system
|
||||
attributes:
|
||||
label: Operating System
|
||||
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: browser
|
||||
attributes:
|
||||
label: Browser (if applicable)
|
||||
description: Specify the browser/version (e.g., Chrome 100.0, Firefox 98.0)
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: confirmation
|
||||
attributes:
|
||||
label: Confirmation
|
||||
description: Ensure the following prerequisites have been met.
|
||||
options:
|
||||
- label: I have read and followed all instructions in `README.md`.
|
||||
required: true
|
||||
- label: I am using the latest version of **both** Open WebUI and Ollama.
|
||||
required: true
|
||||
- label: I have checked the browser console logs.
|
||||
required: true
|
||||
- label: I have checked the Docker container logs.
|
||||
required: true
|
||||
- label: I have listed steps to reproduce the bug in detail.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: Describe what should have happened.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: Describe what actually happened.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction-steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide step-by-step instructions to reproduce the issue.
|
||||
placeholder: |
|
||||
1. Go to '...'
|
||||
2. Click on '...'
|
||||
3. Scroll down to '...'
|
||||
4. See the error message '...'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: logs-screenshots
|
||||
attributes:
|
||||
label: Logs & Screenshots
|
||||
description: Include relevant logs, errors, or screenshots to help diagnose the issue.
|
||||
placeholder: 'Attach logs from the browser console, Docker logs, or error messages.'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: additional-info
|
||||
attributes:
|
||||
label: Additional Information
|
||||
description: Provide any extra details that may assist in understanding the issue.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Note
|
||||
If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
|
||||
Thank you for contributing to Open WebUI!
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
# Feature Request
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
|
||||
|
||||
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
|
||||
|
||||
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
|
||||
|
||||
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
|
||||
|
||||
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
name: Feature Request
|
||||
description: Suggest an idea for this project
|
||||
title: 'feat: '
|
||||
labels: ['triage']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Important Notes
|
||||
### Before submitting
|
||||
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
|
||||
It's likely we're already tracking it! If you’re unsure, start a discussion post first.
|
||||
This will help us efficiently focus on improving the project.
|
||||
|
||||
### Collaborate respectfully
|
||||
We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
|
||||
|
||||
Remember:
|
||||
- Open WebUI is a **volunteer-driven project**
|
||||
- It's managed by a **single maintainer**
|
||||
- It's supported by contributors who also have **full-time jobs**
|
||||
|
||||
We appreciate your time and ask that you **respect ours**.
|
||||
|
||||
|
||||
### Contributing
|
||||
If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
|
||||
|
||||
### Bug reproducibility
|
||||
If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a `pip install` with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "[issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
|
||||
|
||||
- type: checkboxes
|
||||
id: existing-issue
|
||||
attributes:
|
||||
label: Check Existing Issues
|
||||
description: Please confirm that you've checked for existing similar requests
|
||||
options:
|
||||
- label: I have searched the existing issues and discussions.
|
||||
required: true
|
||||
- type: textarea
|
||||
id: problem-description
|
||||
attributes:
|
||||
label: Problem Description
|
||||
description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
|
||||
placeholder: "Ex. I'm always frustrated when..."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: solution-description
|
||||
attributes:
|
||||
label: Desired Solution you'd like
|
||||
description: Clearly describe what you want to happen.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: alternatives-considered
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: A clear and concise description of any alternative solutions or features you've considered.
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context or screenshots about the feature request here.
|
||||
|
|
@ -14,7 +14,7 @@ env:
|
|||
|
||||
jobs:
|
||||
build-main-image:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
|
@ -111,7 +111,7 @@ jobs:
|
|||
retention-days: 1
|
||||
|
||||
build-cuda-image:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
|
@ -211,7 +211,7 @@ jobs:
|
|||
retention-days: 1
|
||||
|
||||
build-ollama-image:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ jobs:
|
|||
uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 18
|
||||
node-version: 22
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.11
|
||||
|
|
|
|||
179
CHANGELOG.md
|
|
@ -5,6 +5,185 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.5.19] - 2024-03-04
|
||||
|
||||
### Added
|
||||
|
||||
- **📊 Logit Bias Parameter Support**: Fine-tune conversation dynamics by adjusting the Logit Bias parameter directly in chat settings, giving you more control over model responses.
|
||||
- **⌨️ Customizable Enter Behavior**: You can now configure Enter to send messages only when combined with Ctrl (Ctrl+Enter) via Settings > Interface, preventing accidental message sends.
|
||||
- **📝 Collapsible Code Blocks**: Easily collapse long code blocks to declutter your chat, making it easier to focus on important details.
|
||||
- **🏷️ Tag Selector in Model Selector**: Quickly find and categorize models with the new tag filtering system in the Model Selector, streamlining model discovery.
|
||||
- **📈 Experimental Elasticsearch Vector DB Support**: Now supports Elasticsearch as a vector database, offering more flexibility for data retrieval in Retrieval-Augmented Generation (RAG) workflows.
|
||||
- **⚙️ General Reliability Enhancements**: Various stability improvements across the WebUI, ensuring a smoother, more consistent experience.
|
||||
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔄 "Stream" Hook Activation**: Fixed an issue where the "Stream" hook only worked when globally enabled, ensuring reliable real-time filtering.
|
||||
- **📧 LDAP Email Case Sensitivity**: Resolved an issue where LDAP login failed due to email case sensitivity mismatches, improving authentication reliability.
|
||||
- **💬 WebSocket Chat Event Registration**: Fixed a bug preventing chat event listeners from being registered upon sign-in, ensuring real-time updates work properly.
|
||||
|
||||
## [0.5.18] - 2025-02-27
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity.
|
||||
- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility.
|
||||
- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing.
|
||||
|
||||
## [0.5.17] - 2025-02-27
|
||||
|
||||
### Added
|
||||
|
||||
- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking.
|
||||
- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering. Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter).
|
||||
- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options.
|
||||
- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient.
|
||||
- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience.
|
||||
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors.
|
||||
- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data.
|
||||
- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant.
|
||||
|
||||
### Changed
|
||||
|
||||
- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI.
|
||||
|
||||
## [0.5.16] - 2025-02-20
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again.
|
||||
|
||||
## [0.5.15] - 2025-02-20
|
||||
|
||||
### Added
|
||||
|
||||
- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding.
|
||||
- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval.
|
||||
- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data.
|
||||
- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI.
|
||||
- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions.
|
||||
- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs.
|
||||
- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations.
|
||||
- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI.
|
||||
- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue.
|
||||
- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow.
|
||||
|
||||
## [0.5.14] - 2025-02-17
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability.
|
||||
|
||||
## [0.5.13] - 2025-02-17
|
||||
|
||||
### Added
|
||||
|
||||
- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results.
|
||||
- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval.
|
||||
- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts.
|
||||
- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat.
|
||||
- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history.
|
||||
- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data.
|
||||
- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow.
|
||||
|
||||
### Changed
|
||||
|
||||
- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; now requires double curly brackets '{{}}' for consistency and clarity.
|
||||
- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages.
|
||||
|
||||
## [0.5.12] - 2025-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- **🛠️ Multiple Tool Calls Support for Native Function Mode**: Functions now can call multiple tools within a single response, unlocking better automation and workflow flexibility when using native function calling.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **📝 Playground Text Completion Restored**: Addressed an issue where text completion in the Playground was not functioning.
|
||||
- **🔗 Direct Connections Now Work for Regular Users**: Fixed a bug where users with the 'user' role couldn't establish direct API connections, enabling seamless model usage for all user tiers.
|
||||
- **⚡ Landing Page Input No Longer Lags with Long Text**: Improved input responsiveness on the landing page, ensuring fast and smooth typing experiences even when entering long messages.
|
||||
- **🔧 Parameter in Functions Fixed**: Fixed an issue where the reserved parameters wasn’t recognized within functions, restoring full functionality for advanced task-based automation.
|
||||
|
||||
## [0.5.11] - 2025-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- **🎤 Kokoro-JS TTS Support**: A new on-device, high-quality text-to-speech engine has been integrated, vastly improving voice generation quality—everything runs directly in your browser.
|
||||
- **🐍 Jupyter Notebook Support in Code Interpreter**: Now, you can configure Code Interpreter to run Python code not only via Pyodide but also through Jupyter, offering a more robust coding environment for AI-driven computations and analysis.
|
||||
- **🔗 Direct API Connections for Private & Local Inference**: You can now connect Open WebUI to your private or localhost API inference endpoints. CORS must be enabled, but this unlocks direct, on-device AI infrastructure support.
|
||||
- **🔍 Advanced Domain Filtering for Web Search**: You can now specify which domains should be included or excluded from web searches, refining results for more relevant information retrieval.
|
||||
- **🚀 Improved Image Generation Metadata Handling**: Generated images now retain metadata for better organization and future retrieval.
|
||||
- **📂 S3 Key Prefix Support**: Fine-grained control over S3 storage file structuring with configurable key prefixes.
|
||||
- **📸 Support for Image-Only Messages**: Send messages containing only images, facilitating more visual-centric interactions.
|
||||
- **🌍 Updated Translations**: German, Spanish, Traditional Chinese, and Catalan translations updated for better multilingual support.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 OAuth Debug Logs & Username Claim Fixes**: Debug logs have been added for OAuth role and group management, with fixes ensuring proper OAuth username retrieval and claim handling.
|
||||
- **📌 Citations Formatting & Toggle Fixes**: Inline citation toggles now function correctly, and citations with more than three sources are now fully visible when expanded.
|
||||
- **📸 ComfyUI Maximum Seed Value Constraint Fixed**: The maximum allowed seed value for ComfyUI has been corrected, preventing unintended behavior.
|
||||
- **🔑 Connection Settings Stability**: Addressed connection settings issues that were causing instability when saving configurations.
|
||||
- **📂 GGUF Model Upload Stability**: Fixed upload inconsistencies for GGUF models, ensuring reliable local model handling.
|
||||
- **🔧 Web Search Configuration Bug**: Fixed issues where web search filters and settings weren't correctly applied.
|
||||
- **💾 User Settings Persistence Fix**: Ensured user-specific settings are correctly saved and applied across sessions.
|
||||
- **🔄 OpenID Username Retrieval Enhancement**: Usernames are now correctly picked up and assigned for OpenID Connect (OIDC) logins.
|
||||
|
||||
## [0.5.10] - 2025-02-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- **⚙️ System Prompts Now Properly Templated via API**: Resolved an issue where system prompts were not being correctly processed when used through the API, ensuring template variables now function as expected.
|
||||
- **📝 '<thinking>' Tag Display Issue Fixed**: Fixed a bug where the 'thinking' tag was disrupting content rendering, ensuring clean and accurate text display.
|
||||
- **💻 Code Interpreter Stability with Custom Functions**: Addressed failures when using the Code Interpreter with certain custom functions like Anthropic, ensuring smoother execution and better compatibility.
|
||||
|
||||
## [0.5.9] - 2025-02-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- **💡 "Think" Tag Display Issue**: Resolved a bug where the "Think" tag was not functioning correctly, ensuring proper visualization of the model's reasoning process before delivering responses.
|
||||
|
||||
## [0.5.8] - 2025-02-05
|
||||
|
||||
### Added
|
||||
|
||||
- **🖥️ Code Interpreter**: Models can now execute code in real time to refine their answers dynamically, running securely within a sandboxed browser environment using Pyodide. Perfect for calculations, data analysis, and AI-assisted coding tasks!
|
||||
- **💬 Redesigned Chat Input UI**: Enjoy a sleeker and more intuitive message input with improved feature selection, making it easier than ever to toggle tools, enable search, and interact with AI seamlessly.
|
||||
- **🛠️ Native Tool Calling Support (Experimental)**: Supported models can now call tools natively, reducing query latency and improving contextual responses. More enhancements coming soon!
|
||||
- **🔗 Exa Search Engine Integration**: A new search provider has been added, allowing users to retrieve up-to-date and relevant information without leaving the chat interface.
|
||||
- **🌍 Localized Dates & Times**: Date and time formats now match your system locale, ensuring a more natural, region-specific experience.
|
||||
- **📎 User Headers for External Embedding APIs**: API calls to external embedding services now include user-related headers.
|
||||
- **🌍 "Always On" Web Search Toggle**: A new option under Settings > Interface allows users to enable Web Search by default—transform Open WebUI into your go-to search engine, ensuring AI-powered results with every query.
|
||||
- **🚀 General Performance & Stability**: Significant improvements across the platform for a faster, more reliable experience.
|
||||
- **🖼️ UI/UX Enhancements**: Numerous design refinements improving readability, responsiveness, and accessibility.
|
||||
- **🌍 Improved Translations**: Chinese, Korean, French, Ukrainian and Serbian translations have been updated with refined terminologies for better clarity.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔄 OAuth Name Field Fallback**: Resolves OAuth login failures by using the email field as a fallback when a name is missing.
|
||||
- **🔑 Google Drive Credentials Restriction**: Ensures only authenticated users can access Google Drive credentials for enhanced security.
|
||||
- **🌐 DuckDuckGo Search Rate Limit Handling**: Fixes issues where users would encounter 202 errors due to rate limits when using DuckDuckGo for web search.
|
||||
- **📁 File Upload Permission Indicator**: Users are now notified when they lack permission to upload files, improving clarity on system restrictions.
|
||||
- **🔧 Max Tokens Issue**: Fixes cases where 'max_tokens' were not applied correctly, ensuring proper model behavior.
|
||||
- **🔍 Validation for RAG Web Search URLs**: Filters out invalid or unsupported URLs when using web-based retrieval augmentation.
|
||||
- **🖋️ Title Generation Bug**: Fixes inconsistencies in title generation, ensuring proper chat organization.
|
||||
|
||||
### Removed
|
||||
|
||||
- **⚡ Deprecated Non-Web Worker Pyodide Execution**: Moves entirely to browser sandboxing for better performance and security.
|
||||
|
||||
## [0.5.7] - 2025-01-23
|
||||
|
||||
### Added
|
||||
|
|
|
|||
11
README.md
|
|
@ -27,10 +27,15 @@ git push origin main
|
|||
|
||||
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
|
||||
|
||||
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
|
||||
|
||||

|
||||
|
||||
> [!TIP]
|
||||
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
|
||||
>
|
||||
> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
|
||||
|
||||
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
|
||||
|
||||
## Key Features of Open WebUI ⭐
|
||||
|
||||
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
|
||||
|
|
@ -188,7 +193,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa
|
|||
|
||||
In the last part of the command, replace `open-webui` with your container name if it is different.
|
||||
|
||||
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
|
||||
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
|
||||
|
||||
### Using the Dev Branch 🌙
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import base64
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import chromadb
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
||||
|
|
@ -42,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
|
|||
|
||||
# Function to run the alembic migrations
|
||||
def run_migrations():
|
||||
print("Running migrations")
|
||||
log.info("Running migrations")
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
|
@ -55,7 +56,7 @@ def run_migrations():
|
|||
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
log.exception(f"Error running migrations: {e}")
|
||||
|
||||
|
||||
run_migrations()
|
||||
|
|
@ -586,6 +587,14 @@ load_oauth_providers()
|
|||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
|
||||
|
||||
for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"):
|
||||
if file_path.is_file():
|
||||
target_path = STATIC_DIR / file_path.relative_to(
|
||||
(FRONTEND_BUILD_DIR / "static")
|
||||
)
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(file_path, target_path)
|
||||
|
||||
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
|
||||
|
||||
if frontend_favicon.exists():
|
||||
|
|
@ -593,8 +602,6 @@ if frontend_favicon.exists():
|
|||
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
else:
|
||||
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
|
||||
|
||||
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
|
||||
|
||||
|
|
@ -603,12 +610,18 @@ if frontend_splash.exists():
|
|||
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
else:
|
||||
logging.warning(f"Frontend splash not found at {frontend_splash}")
|
||||
|
||||
frontend_loader = FRONTEND_BUILD_DIR / "static" / "loader.js"
|
||||
|
||||
if frontend_loader.exists():
|
||||
try:
|
||||
shutil.copyfile(frontend_loader, STATIC_DIR / "loader.js")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
####################################
|
||||
# CUSTOM_NAME
|
||||
# CUSTOM_NAME (Legacy)
|
||||
####################################
|
||||
|
||||
CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
|
||||
|
|
@ -650,6 +663,12 @@ if CUSTOM_NAME:
|
|||
pass
|
||||
|
||||
|
||||
####################################
|
||||
# LICENSE_KEY
|
||||
####################################
|
||||
|
||||
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
|
||||
|
||||
####################################
|
||||
# STORAGE PROVIDER
|
||||
####################################
|
||||
|
|
@ -660,27 +679,47 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
|
|||
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
|
||||
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
|
||||
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
|
||||
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
|
||||
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
|
||||
S3_USE_ACCELERATE_ENDPOINT = (
|
||||
os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true"
|
||||
)
|
||||
S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None)
|
||||
|
||||
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
|
||||
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
|
||||
"GOOGLE_APPLICATION_CREDENTIALS_JSON", None
|
||||
)
|
||||
|
||||
AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None)
|
||||
AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None)
|
||||
AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None)
|
||||
|
||||
####################################
|
||||
# File Upload DIR
|
||||
####################################
|
||||
|
||||
UPLOAD_DIR = f"{DATA_DIR}/uploads"
|
||||
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||
UPLOAD_DIR = DATA_DIR / "uploads"
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Cache DIR
|
||||
####################################
|
||||
|
||||
CACHE_DIR = f"{DATA_DIR}/cache"
|
||||
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||
CACHE_DIR = DATA_DIR / "cache"
|
||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# DIRECT CONNECTIONS
|
||||
####################################
|
||||
|
||||
ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
|
||||
"ENABLE_DIRECT_CONNECTIONS",
|
||||
"direct.enable",
|
||||
os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
|
||||
)
|
||||
|
||||
####################################
|
||||
# OLLAMA_BASE_URL
|
||||
|
|
@ -755,6 +794,9 @@ ENABLE_OPENAI_API = PersistentConfig(
|
|||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
||||
|
||||
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
||||
GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
|
||||
|
||||
|
||||
if OPENAI_API_BASE_URL == "":
|
||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||
|
|
@ -927,6 +969,12 @@ USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = (
|
|||
== "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
|
||||
os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_USER_PERMISSIONS = {
|
||||
"workspace": {
|
||||
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
|
||||
|
|
@ -944,6 +992,7 @@ DEFAULT_USER_PERMISSIONS = {
|
|||
"features": {
|
||||
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
|
||||
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
|
||||
"code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -1052,7 +1101,7 @@ try:
|
|||
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
|
||||
banners = [BannerModel(**banner) for banner in banners]
|
||||
except Exception as e:
|
||||
print(f"Error loading WEBUI_BANNERS: {e}")
|
||||
log.exception(f"Error loading WEBUI_BANNERS: {e}")
|
||||
banners = []
|
||||
|
||||
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
|
||||
|
|
@ -1094,21 +1143,27 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
|||
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
|
||||
|
||||
Examples of titles:
|
||||
📉 Stock Market Trends
|
||||
🍪 Perfect Chocolate Chip Recipe
|
||||
Evolution of Music Streaming
|
||||
Remote Work Productivity Tips
|
||||
Artificial Intelligence in Healthcare
|
||||
🎮 Video Game Development Insights
|
||||
|
||||
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """### Task:
|
||||
Generate a concise, 3-5 word title with an emoji summarizing the chat history.
|
||||
### Guidelines:
|
||||
- The title should clearly represent the main theme or subject of the conversation.
|
||||
- Use emojis that enhance understanding of the topic, but avoid quotation marks or special formatting.
|
||||
- Write the title in the chat's primary language; default to English if multilingual.
|
||||
- Prioritize accuracy over excessive creativity; keep it clear and simple.
|
||||
### Output:
|
||||
JSON format: { "title": "your concise title here" }
|
||||
### Examples:
|
||||
- { "title": "📉 Stock Market Trends" },
|
||||
- { "title": "🍪 Perfect Chocolate Chip Recipe" },
|
||||
- { "title": "Evolution of Music Streaming" },
|
||||
- { "title": "Remote Work Productivity Tips" },
|
||||
- { "title": "Artificial Intelligence in Healthcare" },
|
||||
- { "title": "🎮 Video Game Development Insights" }
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:2}}
|
||||
</chat_history>"""
|
||||
|
||||
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.tags.prompt_template",
|
||||
|
|
@ -1165,6 +1220,12 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
|
|||
os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
|
||||
)
|
||||
|
||||
ENABLE_TITLE_GENERATION = PersistentConfig(
|
||||
"ENABLE_TITLE_GENERATION",
|
||||
"task.title.enable",
|
||||
os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
|
||||
"ENABLE_SEARCH_QUERY_GENERATION",
|
||||
|
|
@ -1277,7 +1338,28 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
|||
)
|
||||
|
||||
|
||||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
|
||||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}
|
||||
|
||||
Your task is to choose and return the correct tool(s) from the list of available tools based on the query. Follow these guidelines:
|
||||
|
||||
- Return only the JSON object, without any additional text or explanation.
|
||||
|
||||
- If no tools match the query, return an empty array:
|
||||
{
|
||||
"tool_calls": []
|
||||
}
|
||||
|
||||
- If one or more tools match the query, construct a JSON response containing a "tool_calls" array with objects that include:
|
||||
- "name": The tool's name.
|
||||
- "parameters": A dictionary of required parameters and their corresponding values.
|
||||
|
||||
The format for the JSON response is strictly:
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "toolName1", "parameters": {"key1": "value1"}},
|
||||
{"name": "toolName2", "parameters": {"key2": "value2"}}
|
||||
]
|
||||
}"""
|
||||
|
||||
|
||||
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
||||
|
|
@ -1290,6 +1372,131 @@ Your task is to synthesize these responses into a single, high-quality response.
|
|||
|
||||
Responses from models: {{responses}}"""
|
||||
|
||||
|
||||
####################################
|
||||
# Code Interpreter
|
||||
####################################
|
||||
|
||||
|
||||
CODE_EXECUTION_ENGINE = PersistentConfig(
|
||||
"CODE_EXECUTION_ENGINE",
|
||||
"code_execution.engine",
|
||||
os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_URL = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_URL",
|
||||
"code_execution.jupyter.url",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH",
|
||||
"code_execution.jupyter.auth",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN",
|
||||
"code_execution.jupyter.auth_token",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD",
|
||||
"code_execution.jupyter.auth_password",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
|
||||
)
|
||||
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig(
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT",
|
||||
"code_execution.jupyter.timeout",
|
||||
int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")),
|
||||
)
|
||||
|
||||
ENABLE_CODE_INTERPRETER = PersistentConfig(
|
||||
"ENABLE_CODE_INTERPRETER",
|
||||
"code_interpreter.enable",
|
||||
os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_ENGINE = PersistentConfig(
|
||||
"CODE_INTERPRETER_ENGINE",
|
||||
"code_interpreter.engine",
|
||||
os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE",
|
||||
"code_interpreter.prompt_template",
|
||||
os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_URL",
|
||||
"code_interpreter.jupyter.url",
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "")
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH",
|
||||
"code_interpreter.jupyter.auth",
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
|
||||
"code_interpreter.jupyter.auth_token",
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
|
||||
"code_interpreter.jupyter.auth_password",
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
|
||||
),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
|
||||
"code_interpreter.jupyter.timeout",
|
||||
int(
|
||||
os.environ.get(
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
|
||||
os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_CODE_INTERPRETER_PROMPT = """
|
||||
#### Tools Available
|
||||
|
||||
1. **Code Interpreter**: `<code_interpreter type="code" lang="python"></code_interpreter>`
|
||||
- You have access to a Python shell that runs directly in the user's browser, enabling fast execution of code for analysis, calculations, or problem-solving. Use it in this response.
|
||||
- The Python code you write can incorporate a wide array of libraries, handle data manipulation or visualization, perform API calls for web-related tasks, or tackle virtually any computational challenge. Use this flexibility to **think outside the box, craft elegant solutions, and harness Python's full potential**.
|
||||
- To use it, **you must enclose your code within `<code_interpreter type="code" lang="python">` XML tags** and stop right away. If you don't, the code won't execute. Do NOT use triple backticks.
|
||||
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
|
||||
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
|
||||
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
|
||||
- **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
|
||||
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
|
||||
|
||||
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
|
||||
|
||||
|
||||
####################################
|
||||
# Vector Database
|
||||
####################################
|
||||
|
|
@ -1298,27 +1505,34 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
|
|||
|
||||
# Chroma
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
|
||||
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
|
||||
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
|
||||
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
|
||||
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
|
||||
# Comma-separated list of header=value pairs
|
||||
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
|
||||
if CHROMA_HTTP_HEADERS:
|
||||
CHROMA_HTTP_HEADERS = dict(
|
||||
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
|
||||
|
||||
if VECTOR_DB == "chroma":
|
||||
import chromadb
|
||||
|
||||
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
|
||||
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
|
||||
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
|
||||
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
|
||||
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get(
|
||||
"CHROMA_CLIENT_AUTH_CREDENTIALS", ""
|
||||
)
|
||||
else:
|
||||
CHROMA_HTTP_HEADERS = None
|
||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# Comma-separated list of header=value pairs
|
||||
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
|
||||
if CHROMA_HTTP_HEADERS:
|
||||
CHROMA_HTTP_HEADERS = dict(
|
||||
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
|
||||
)
|
||||
else:
|
||||
CHROMA_HTTP_HEADERS = None
|
||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
# Milvus
|
||||
|
||||
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
|
||||
MILVUS_DB = os.environ.get("MILVUS_DB", "default")
|
||||
MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
|
||||
|
||||
# Qdrant
|
||||
QDRANT_URI = os.environ.get("QDRANT_URI", None)
|
||||
|
|
@ -1331,6 +1545,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
|
|||
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
||||
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
||||
|
||||
# ElasticSearch
|
||||
ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
|
||||
ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
|
||||
ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
|
||||
ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
|
||||
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
|
||||
ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
|
||||
SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)
|
||||
|
||||
# Pgvector
|
||||
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
|
||||
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
|
||||
|
|
@ -1365,6 +1588,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig(
|
|||
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
|
||||
)
|
||||
|
||||
ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
|
||||
"ENABLE_ONEDRIVE_INTEGRATION",
|
||||
"onedrive.enable",
|
||||
os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
ONEDRIVE_CLIENT_ID = PersistentConfig(
|
||||
"ONEDRIVE_CLIENT_ID",
|
||||
"onedrive.client_id",
|
||||
os.environ.get("ONEDRIVE_CLIENT_ID", ""),
|
||||
)
|
||||
|
||||
# RAG Content Extraction
|
||||
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
"CONTENT_EXTRACTION_ENGINE",
|
||||
|
|
@ -1384,6 +1619,26 @@ DOCLING_SERVER_URL = PersistentConfig(
|
|||
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||
"rag.document_intelligence_endpoint",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_KEY",
|
||||
"rag.document_intelligence_key",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
|
||||
)
|
||||
|
||||
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL",
|
||||
"rag.bypass_embedding_and_retrieval",
|
||||
os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
RAG_TOP_K = PersistentConfig(
|
||||
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
|
||||
)
|
||||
|
|
@ -1399,6 +1654,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
|
|||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_FULL_CONTEXT = PersistentConfig(
|
||||
"RAG_FULL_CONTEXT",
|
||||
"rag.full_context",
|
||||
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
RAG_FILE_MAX_COUNT = PersistentConfig(
|
||||
"RAG_FILE_MAX_COUNT",
|
||||
"rag.file.max_count",
|
||||
|
|
@ -1513,7 +1774,7 @@ Respond to the user query using the provided context, incorporating inline citat
|
|||
- Respond in the same language as the user's query.
|
||||
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
|
||||
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
|
||||
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
|
||||
- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `<source_id>` tag is explicitly provided in the context.**
|
||||
- Do not cite if the <source_id> tag is not provided in the context.
|
||||
- Do not use XML tags in your response.
|
||||
- Ensure citations are concise and directly related to the information provided.
|
||||
|
|
@ -1594,11 +1855,17 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
|
|||
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
|
||||
)
|
||||
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
|
||||
"rag.web.search.bypass_embedding_and_retrieval",
|
||||
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
|
||||
)
|
||||
|
||||
# You can provide a list of your own websites to filter after performing a web search.
|
||||
# This ensures the highest level of safety and reliability of the information sources.
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
|
||||
"rag.rag.web.search.domain.filter_list",
|
||||
"rag.web.search.domain.filter_list",
|
||||
[
|
||||
# "wikipedia.com",
|
||||
# "wikimedia.org",
|
||||
|
|
@ -1643,6 +1910,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
|
|||
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
|
||||
)
|
||||
|
||||
BOCHA_SEARCH_API_KEY = PersistentConfig(
|
||||
"BOCHA_SEARCH_API_KEY",
|
||||
"rag.web.search.bocha_search_api_key",
|
||||
os.getenv("BOCHA_SEARCH_API_KEY", ""),
|
||||
)
|
||||
|
||||
SERPSTACK_API_KEY = PersistentConfig(
|
||||
"SERPSTACK_API_KEY",
|
||||
"rag.web.search.serpstack_api_key",
|
||||
|
|
@ -1691,6 +1964,18 @@ SEARCHAPI_ENGINE = PersistentConfig(
|
|||
os.getenv("SEARCHAPI_ENGINE", ""),
|
||||
)
|
||||
|
||||
SERPAPI_API_KEY = PersistentConfig(
|
||||
"SERPAPI_API_KEY",
|
||||
"rag.web.search.serpapi_api_key",
|
||||
os.getenv("SERPAPI_API_KEY", ""),
|
||||
)
|
||||
|
||||
SERPAPI_ENGINE = PersistentConfig(
|
||||
"SERPAPI_ENGINE",
|
||||
"rag.web.search.serpapi_engine",
|
||||
os.getenv("SERPAPI_ENGINE", ""),
|
||||
)
|
||||
|
||||
BING_SEARCH_V7_ENDPOINT = PersistentConfig(
|
||||
"BING_SEARCH_V7_ENDPOINT",
|
||||
"rag.web.search.bing_search_v7_endpoint",
|
||||
|
|
@ -1705,6 +1990,17 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
|
|||
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
|
||||
)
|
||||
|
||||
EXA_API_KEY = PersistentConfig(
|
||||
"EXA_API_KEY",
|
||||
"rag.web.search.exa_api_key",
|
||||
os.getenv("EXA_API_KEY", ""),
|
||||
)
|
||||
|
||||
PERPLEXITY_API_KEY = PersistentConfig(
|
||||
"PERPLEXITY_API_KEY",
|
||||
"rag.web.search.perplexity_api_key",
|
||||
os.getenv("PERPLEXITY_API_KEY", ""),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_RESULT_COUNT",
|
||||
|
|
@ -1718,6 +2014,35 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
|||
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
RAG_WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"RAG_WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
|
||||
)
|
||||
|
||||
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_TRUST_ENV",
|
||||
"rag.web.search.trust_env",
|
||||
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
|
||||
)
|
||||
|
||||
PLAYWRIGHT_WS_URI = PersistentConfig(
|
||||
"PLAYWRIGHT_WS_URI",
|
||||
"rag.web.loader.engine.playwright.ws.uri",
|
||||
os.environ.get("PLAYWRIGHT_WS_URI", None),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_KEY = PersistentConfig(
|
||||
"FIRECRAWL_API_KEY",
|
||||
"firecrawl.api_key",
|
||||
os.environ.get("FIRECRAWL_API_KEY", ""),
|
||||
)
|
||||
|
||||
FIRECRAWL_API_BASE_URL = PersistentConfig(
|
||||
"FIRECRAWL_API_BASE_URL",
|
||||
"firecrawl.api_url",
|
||||
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
|
||||
)
|
||||
|
||||
####################################
|
||||
# Images
|
||||
|
|
@ -1929,6 +2254,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
|
|||
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||
)
|
||||
|
||||
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
|
||||
"IMAGES_GEMINI_API_BASE_URL",
|
||||
"image_generation.gemini.api_base_url",
|
||||
os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
|
||||
)
|
||||
IMAGES_GEMINI_API_KEY = PersistentConfig(
|
||||
"IMAGES_GEMINI_API_KEY",
|
||||
"image_generation.gemini.api_key",
|
||||
os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
|
||||
)
|
||||
|
||||
IMAGE_SIZE = PersistentConfig(
|
||||
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
|
||||
)
|
||||
|
|
@ -1960,6 +2296,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
|
|||
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Add Deepgram configuration
|
||||
DEEPGRAM_API_KEY = PersistentConfig(
|
||||
"DEEPGRAM_API_KEY",
|
||||
"audio.stt.deepgram.api_key",
|
||||
os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
)
|
||||
|
||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||
|
|
@ -2099,7 +2441,7 @@ LDAP_SEARCH_BASE = PersistentConfig(
|
|||
LDAP_SEARCH_FILTERS = PersistentConfig(
|
||||
"LDAP_SEARCH_FILTER",
|
||||
"ldap.server.search_filter",
|
||||
os.environ.get("LDAP_SEARCH_FILTER", ""),
|
||||
os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")),
|
||||
)
|
||||
|
||||
LDAP_USE_TLS = PersistentConfig(
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class ERROR_MESSAGES(str, Enum):
|
|||
)
|
||||
|
||||
FILE_NOT_SENT = "FILE_NOT_SENT"
|
||||
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format (e.g., JPG, PNG, PDF, TXT) and try again."
|
||||
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
|
||||
|
||||
NOT_FOUND = "We could not find what you're looking for :/"
|
||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||
|
|
|
|||
|
|
@ -65,10 +65,8 @@ except Exception:
|
|||
# LOGGING
|
||||
####################################
|
||||
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
|
@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
|||
|
||||
if "cuda_error" in locals():
|
||||
log.exception(cuda_error)
|
||||
del cuda_error
|
||||
|
||||
log_sources = [
|
||||
"AUDIO",
|
||||
|
|
@ -92,6 +91,7 @@ log_sources = [
|
|||
"RAG",
|
||||
"WEBHOOK",
|
||||
"SOCKET",
|
||||
"OAUTH",
|
||||
]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
|
@ -99,7 +99,7 @@ SRC_LOG_LEVELS = {}
|
|||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
|
|
@ -112,6 +112,7 @@ if WEBUI_NAME != "Open WebUI":
|
|||
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
|
||||
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
|
||||
|
||||
####################################
|
||||
# ENV (dev,test,prod)
|
||||
|
|
@ -356,14 +357,22 @@ WEBUI_SECRET_KEY = os.environ.get(
|
|||
), # DEPRECATED: remove at next major version
|
||||
)
|
||||
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
|
||||
"WEBUI_SESSION_COOKIE_SAME_SITE",
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")
|
||||
|
||||
WEBUI_SESSION_COOKIE_SECURE = (
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
|
||||
)
|
||||
|
||||
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
|
||||
"WEBUI_SESSION_COOKIE_SECURE",
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get(
|
||||
"WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE
|
||||
)
|
||||
|
||||
WEBUI_AUTH_COOKIE_SECURE = (
|
||||
os.environ.get(
|
||||
"WEBUI_AUTH_COOKIE_SECURE",
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"),
|
||||
).lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
||||
|
|
@ -376,6 +385,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
|||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
|
|
@ -387,19 +397,20 @@ else:
|
|||
except Exception:
|
||||
AIOHTTP_CLIENT_TIMEOUT = 300
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
|
||||
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
|
||||
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
|
||||
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
|
||||
)
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
|
||||
else:
|
||||
try:
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
||||
)
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
except Exception:
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5
|
||||
|
||||
|
||||
####################################
|
||||
# OFFLINE_MODE
|
||||
|
|
@ -409,3 +420,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
|||
|
||||
if OFFLINE_MODE:
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
####################################
|
||||
# AUDIT LOGGING
|
||||
####################################
|
||||
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
|
||||
# Where to store log file
|
||||
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||
# Maximum size of a file before rotating into a new log file
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||
# METADATA | REQUEST | REQUEST_RESPONSE
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
|
||||
try:
|
||||
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
|
||||
except ValueError:
|
||||
MAX_BODY_LOG_SIZE = 2048
|
||||
|
||||
# Comma separated list for urls to exclude from audit
|
||||
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
|
||||
","
|
||||
)
|
||||
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
|
||||
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
import sys
|
||||
import inspect
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import AsyncGenerator, Generator, Iterator
|
||||
|
|
@ -76,11 +77,13 @@ async def get_function_models(request):
|
|||
if hasattr(function_module, "pipes"):
|
||||
sub_pipes = []
|
||||
|
||||
# Check if pipes is a function or a list
|
||||
|
||||
# Handle pipes being a list, sync function, or async function
|
||||
try:
|
||||
if callable(function_module.pipes):
|
||||
sub_pipes = function_module.pipes()
|
||||
if asyncio.iscoroutinefunction(function_module.pipes):
|
||||
sub_pipes = await function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes
|
||||
except Exception as e:
|
||||
|
|
@ -250,7 +253,7 @@ async def generate_function_chat_completion(
|
|||
|
||||
params = model_info.params.model_dump()
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_model_system_prompt_to_body(params, form_data, user)
|
||||
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
|
||||
|
||||
pipe_id = get_pipe_id(form_data)
|
||||
function_module = get_function_module_by_id(request, pipe_id)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils import logger
|
||||
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
|
||||
from open_webui.utils.logger import start_logger
|
||||
from open_webui.socket.main import (
|
||||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
|
|
@ -88,15 +91,34 @@ from open_webui.models.models import Models
|
|||
from open_webui.models.users import UserModel, Users
|
||||
|
||||
from open_webui.config import (
|
||||
LICENSE_KEY,
|
||||
# Ollama
|
||||
ENABLE_OLLAMA_API,
|
||||
OLLAMA_BASE_URLS,
|
||||
OLLAMA_API_CONFIGS,
|
||||
# OpenAI
|
||||
ENABLE_OPENAI_API,
|
||||
ONEDRIVE_CLIENT_ID,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
OPENAI_API_CONFIGS,
|
||||
# Direct Connections
|
||||
ENABLE_DIRECT_CONNECTIONS,
|
||||
# Code Execution
|
||||
CODE_EXECUTION_ENGINE,
|
||||
CODE_EXECUTION_JUPYTER_URL,
|
||||
CODE_EXECUTION_JUPYTER_AUTH,
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
ENABLE_CODE_INTERPRETER,
|
||||
CODE_INTERPRETER_ENGINE,
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
CODE_INTERPRETER_JUPYTER_URL,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
# Image
|
||||
AUTOMATIC1111_API_AUTH,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
|
|
@ -115,6 +137,8 @@ from open_webui.config import (
|
|||
IMAGE_STEPS,
|
||||
IMAGES_OPENAI_API_BASE_URL,
|
||||
IMAGES_OPENAI_API_KEY,
|
||||
IMAGES_GEMINI_API_BASE_URL,
|
||||
IMAGES_GEMINI_API_KEY,
|
||||
# Audio
|
||||
AUDIO_STT_ENGINE,
|
||||
AUDIO_STT_MODEL,
|
||||
|
|
@ -129,12 +153,19 @@ from open_webui.config import (
|
|||
AUDIO_TTS_VOICE,
|
||||
AUDIO_TTS_AZURE_SPEECH_REGION,
|
||||
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
WHISPER_MODEL,
|
||||
DEEPGRAM_API_KEY,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
# Retrieval
|
||||
RAG_TEMPLATE,
|
||||
DEFAULT_RAG_TEMPLATE,
|
||||
RAG_FULL_CONTEXT,
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
|
|
@ -155,6 +186,8 @@ from open_webui.config import (
|
|||
CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
RAG_TOP_K,
|
||||
RAG_TEXT_SPLITTER,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
|
|
@ -163,12 +196,16 @@ from open_webui.config import (
|
|||
YOUTUBE_LOADER_PROXY_URL,
|
||||
# Retrieval (Web Search)
|
||||
RAG_WEB_SEARCH_ENGINE,
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
RAG_WEB_SEARCH_TRUST_ENV,
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
JINA_API_KEY,
|
||||
SEARCHAPI_API_KEY,
|
||||
SEARCHAPI_ENGINE,
|
||||
SERPAPI_API_KEY,
|
||||
SERPAPI_ENGINE,
|
||||
SEARXNG_QUERY_URL,
|
||||
SERPER_API_KEY,
|
||||
SERPLY_API_KEY,
|
||||
|
|
@ -178,17 +215,22 @@ from open_webui.config import (
|
|||
BING_SEARCH_V7_ENDPOINT,
|
||||
BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
EXA_API_KEY,
|
||||
PERPLEXITY_API_KEY,
|
||||
KAGI_SEARCH_API_KEY,
|
||||
MOJEEK_SEARCH_API_KEY,
|
||||
BOCHA_SEARCH_API_KEY,
|
||||
GOOGLE_PSE_API_KEY,
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
GOOGLE_DRIVE_CLIENT_ID,
|
||||
GOOGLE_DRIVE_API_KEY,
|
||||
ONEDRIVE_CLIENT_ID,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
ENABLE_ONEDRIVE_INTEGRATION,
|
||||
UPLOAD_DIR,
|
||||
# WebUI
|
||||
WEBUI_AUTH,
|
||||
|
|
@ -252,6 +294,7 @@ from open_webui.config import (
|
|||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
ENABLE_TAGS_GENERATION,
|
||||
ENABLE_TITLE_GENERATION,
|
||||
ENABLE_SEARCH_QUERY_GENERATION,
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
|
|
@ -266,8 +309,11 @@ from open_webui.config import (
|
|||
reset_config,
|
||||
)
|
||||
from open_webui.env import (
|
||||
AUDIT_EXCLUDED_PATHS,
|
||||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
MAX_BODY_LOG_SIZE,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
|
|
@ -298,15 +344,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo
|
|||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from open_webui.utils.auth import (
|
||||
get_license_data,
|
||||
decode_token,
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
)
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
from open_webui.utils.oauth import OAuthManager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
|
|
@ -322,19 +370,23 @@ class SPAStaticFiles(StaticFiles):
|
|||
return await super().get_response(path, scope)
|
||||
except (HTTPException, StarletteHTTPException) as ex:
|
||||
if ex.status_code == 404:
|
||||
return await super().get_response("index.html", scope)
|
||||
if path.endswith(".js"):
|
||||
# Return 404 for javascript files
|
||||
raise ex
|
||||
else:
|
||||
return await super().get_response("index.html", scope)
|
||||
else:
|
||||
raise ex
|
||||
|
||||
|
||||
print(
|
||||
rf"""
|
||||
___ __ __ _ _ _ ___
|
||||
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
|
||||
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
|
||||
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
|
||||
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|
||||
|_|
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗ ██╗ ██╗███████╗██████╗ ██╗ ██╗██╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║ ██║ ██║██╔════╝██╔══██╗██║ ██║██║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║ ██║ █╗ ██║█████╗ ██████╔╝██║ ██║██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║ ██║███╗██║██╔══╝ ██╔══██╗██║ ██║██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║ ╚███╔███╔╝███████╗██████╔╝╚██████╔╝██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝ ╚══╝╚══╝ ╚══════╝╚═════╝ ╚═════╝ ╚═╝
|
||||
|
||||
|
||||
v{VERSION} - building the best open-source AI user interface.
|
||||
|
|
@ -346,9 +398,13 @@ https://github.com/open-webui/open-webui
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
start_logger()
|
||||
if RESET_CONFIG_ON_START:
|
||||
reset_config()
|
||||
|
||||
if LICENSE_KEY:
|
||||
get_license_data(app, LICENSE_KEY)
|
||||
|
||||
asyncio.create_task(periodic_usage_pool_cleanup())
|
||||
yield
|
||||
|
||||
|
|
@ -360,8 +416,12 @@ app = FastAPI(
|
|||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
oauth_manager = OAuthManager(app)
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.WEBUI_NAME = WEBUI_NAME
|
||||
app.state.LICENSE_METADATA = None
|
||||
|
||||
########################################
|
||||
#
|
||||
|
|
@ -389,6 +449,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
|
|||
|
||||
app.state.OPENAI_MODELS = {}
|
||||
|
||||
########################################
|
||||
#
|
||||
# DIRECT CONNECTIONS
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
|
||||
|
||||
########################################
|
||||
#
|
||||
# WEBUI
|
||||
|
|
@ -455,10 +523,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
|
|||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
|
||||
|
||||
app.state.USER_COUNT = None
|
||||
app.state.TOOLS = {}
|
||||
app.state.FUNCTIONS = {}
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
# RETRIEVAL
|
||||
|
|
@ -471,6 +539,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
|||
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
||||
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
|
||||
|
||||
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
|
||||
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
|
|
@ -479,6 +550,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
|||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||
|
||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||
|
|
@ -506,15 +579,20 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
|
|||
|
||||
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
||||
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
|
||||
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
|
||||
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
|
||||
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
|
||||
app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
|
||||
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
|
||||
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
|
||||
app.state.config.SERPER_API_KEY = SERPER_API_KEY
|
||||
|
|
@ -522,12 +600,21 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
|
|||
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
||||
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
|
||||
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
|
||||
app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY
|
||||
app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE
|
||||
app.state.config.JINA_API_KEY = JINA_API_KEY
|
||||
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
|
||||
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
|
||||
app.state.config.EXA_API_KEY = EXA_API_KEY
|
||||
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
||||
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
|
||||
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
|
||||
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = None
|
||||
app.state.ef = None
|
||||
|
|
@ -569,6 +656,34 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
########################################
|
||||
#
|
||||
# CODE EXECUTION
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
|
||||
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
|
||||
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
|
||||
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
|
||||
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
|
||||
########################################
|
||||
#
|
||||
|
|
@ -583,6 +698,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
|
|||
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||
|
||||
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
||||
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
||||
|
||||
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
|
||||
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
|
|
@ -611,6 +729,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
|||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||
|
||||
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
|
||||
|
||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||
|
|
@ -645,6 +764,7 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
|
|||
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
|
||||
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
||||
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
|
||||
|
||||
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
|
@ -753,6 +873,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
|
|||
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
|
||||
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
|
||||
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
|
||||
|
||||
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
|
||||
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
|
||||
|
||||
|
|
@ -781,6 +902,19 @@ app.include_router(
|
|||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||
|
||||
|
||||
try:
|
||||
audit_level = AuditLevel(AUDIT_LOG_LEVEL)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
|
||||
audit_level = AuditLevel.NONE
|
||||
|
||||
if audit_level != AuditLevel.NONE:
|
||||
app.add_middleware(
|
||||
AuditLoggingMiddleware,
|
||||
audit_level=audit_level,
|
||||
excluded_paths=AUDIT_EXCLUDED_PATHS,
|
||||
max_body_size=MAX_BODY_LOG_SIZE,
|
||||
)
|
||||
##################################
|
||||
#
|
||||
# Chat Endpoints
|
||||
|
|
@ -813,7 +947,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
return filtered_models
|
||||
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
|
|
@ -842,7 +976,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
@app.get("/api/models/base")
|
||||
async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
||||
models = await get_all_base_models(request)
|
||||
models = await get_all_base_models(request, user=user)
|
||||
return {"data": models}
|
||||
|
||||
|
||||
|
|
@ -853,21 +987,32 @@ async def chat_completion(
|
|||
user=Depends(get_verified_user),
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
model_item = form_data.pop("model_item", {})
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
try:
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
model = request.app.state.MODELS[model_id]
|
||||
|
||||
# Check if user has access to the model
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
try:
|
||||
if not model_item.get("direct", False):
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
|
||||
model = request.app.state.MODELS[model_id]
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
model = model_item
|
||||
model_info = None
|
||||
|
||||
request.state.direct = True
|
||||
request.state.model = model
|
||||
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
|
|
@ -877,13 +1022,30 @@ async def chat_completion(
|
|||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
"variables": form_data.get("variables", None),
|
||||
"model": model,
|
||||
"direct": model_item.get("direct", False),
|
||||
**(
|
||||
{"function_calling": "native"}
|
||||
if form_data.get("params", {}).get("function_calling") == "native"
|
||||
or (
|
||||
model_info
|
||||
and model_info.params.model_dump().get("function_calling")
|
||||
== "native"
|
||||
)
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
request.state.metadata = metadata
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
form_data, events = await process_chat_payload(
|
||||
request, form_data, metadata, user, model
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.debug(f"Error processing chat payload: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
|
|
@ -891,8 +1053,9 @@ async def chat_completion(
|
|||
|
||||
try:
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, events, metadata, tasks
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -911,6 +1074,12 @@ async def chat_completed(
|
|||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
model_item = form_data.pop("model_item", {})
|
||||
|
||||
if model_item.get("direct", False):
|
||||
request.state.direct = True
|
||||
request.state.model = model_item
|
||||
|
||||
return await chat_completed_handler(request, form_data, user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -924,6 +1093,12 @@ async def chat_action(
|
|||
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
model_item = form_data.pop("model_item", {})
|
||||
|
||||
if model_item.get("direct", False):
|
||||
request.state.direct = True
|
||||
request.state.model = model_item
|
||||
|
||||
return await chat_action_handler(request, action_id, form_data, user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -969,15 +1144,16 @@ async def get_app_config(request: Request):
|
|||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
onboarding = False
|
||||
|
||||
if user is None:
|
||||
user_count = Users.get_num_users()
|
||||
onboarding = user_count == 0
|
||||
|
||||
return {
|
||||
**({"onboarding": True} if onboarding else {}),
|
||||
"status": True,
|
||||
"name": WEBUI_NAME,
|
||||
"name": app.state.WEBUI_NAME,
|
||||
"version": VERSION,
|
||||
"default_locale": str(DEFAULT_LOCALE),
|
||||
"oauth": {
|
||||
|
|
@ -996,27 +1172,31 @@ async def get_app_config(request: Request):
|
|||
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
|
||||
**(
|
||||
{
|
||||
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
}
|
||||
if user is not None
|
||||
else {}
|
||||
),
|
||||
},
|
||||
"google_drive": {
|
||||
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
||||
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
||||
},
|
||||
**(
|
||||
{
|
||||
"default_models": app.state.config.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
"user_count": user_count,
|
||||
"code": {
|
||||
"engine": app.state.config.CODE_EXECUTION_ENGINE,
|
||||
},
|
||||
"audio": {
|
||||
"tts": {
|
||||
"engine": app.state.config.TTS_ENGINE,
|
||||
|
|
@ -1032,6 +1212,19 @@ async def get_app_config(request: Request):
|
|||
"max_count": app.state.config.FILE_MAX_COUNT,
|
||||
},
|
||||
"permissions": {**app.state.config.USER_PERMISSIONS},
|
||||
"google_drive": {
|
||||
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
|
||||
"api_key": GOOGLE_DRIVE_API_KEY.value,
|
||||
},
|
||||
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
|
||||
"license_metadata": app.state.LICENSE_METADATA,
|
||||
**(
|
||||
{
|
||||
"active_entries": app.state.USER_COUNT,
|
||||
}
|
||||
if user.role == "admin"
|
||||
else {}
|
||||
),
|
||||
}
|
||||
if user is not None
|
||||
else {}
|
||||
|
|
@ -1065,7 +1258,7 @@ async def get_app_version():
|
|||
|
||||
|
||||
@app.get("/api/version/updates")
|
||||
async def get_app_latest_release_version():
|
||||
async def get_app_latest_release_version(user=Depends(get_verified_user)):
|
||||
if OFFLINE_MODE:
|
||||
log.debug(
|
||||
f"Offline mode is enabled, returning current version as latest version"
|
||||
|
|
@ -1109,7 +1302,7 @@ if len(OAUTH_PROVIDERS) > 0:
|
|||
|
||||
@app.get("/oauth/{provider}/login")
|
||||
async def oauth_login(provider: str, request: Request):
|
||||
return await oauth_manager.handle_login(provider, request)
|
||||
return await oauth_manager.handle_login(request, provider)
|
||||
|
||||
|
||||
# OAuth login logic is as follows:
|
||||
|
|
@ -1120,14 +1313,14 @@ async def oauth_login(provider: str, request: Request):
|
|||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||
@app.get("/oauth/{provider}/callback")
|
||||
async def oauth_callback(provider: str, request: Request, response: Response):
|
||||
return await oauth_manager.handle_callback(provider, request, response)
|
||||
return await oauth_manager.handle_callback(request, provider, response)
|
||||
|
||||
|
||||
@app.get("/manifest.json")
|
||||
async def get_manifest_json():
|
||||
return {
|
||||
"name": WEBUI_NAME,
|
||||
"short_name": WEBUI_NAME,
|
||||
"name": app.state.WEBUI_NAME,
|
||||
"short_name": app.state.WEBUI_NAME,
|
||||
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
|
||||
"start_url": "/",
|
||||
"display": "standalone",
|
||||
|
|
@ -1154,8 +1347,8 @@ async def get_manifest_json():
|
|||
async def get_opensearch_xml():
|
||||
xml_content = rf"""
|
||||
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
|
||||
<ShortName>{WEBUI_NAME}</ShortName>
|
||||
<Description>Search {WEBUI_NAME}</Description>
|
||||
<ShortName>{app.state.WEBUI_NAME}</ShortName>
|
||||
<Description>Search {app.state.WEBUI_NAME}</Description>
|
||||
<InputEncoding>UTF-8</InputEncoding>
|
||||
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
|
||||
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
|
|
@ -5,7 +6,7 @@ from typing import Optional
|
|||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
|
|
@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
|
|||
# Chat DB Schema
|
||||
####################
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chat"
|
||||
|
|
@ -470,7 +474,7 @@ class ChatTable:
|
|||
try:
|
||||
with get_db() as db:
|
||||
# it is possible that the shared link was deleted. hence,
|
||||
# we check if the chat is still shared by checkng if a chat with the share_id exists
|
||||
# we check if the chat is still shared by checking if a chat with the share_id exists
|
||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||
|
||||
if chat:
|
||||
|
|
@ -670,7 +674,7 @@ class ChatTable:
|
|||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
print(len(all_chats))
|
||||
log.info(f"The number of chats: {len(all_chats)}")
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
|
@ -731,7 +735,7 @@ class ChatTable:
|
|||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
print(db.bind.dialect.name)
|
||||
log.info(f"DB dialect name: {db.bind.dialect.name}")
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 querying for tags within the meta JSON field
|
||||
query = query.filter(
|
||||
|
|
@ -752,7 +756,7 @@ class ChatTable:
|
|||
)
|
||||
|
||||
all_chats = query.all()
|
||||
print("all_chats", all_chats)
|
||||
log.debug(f"all_chats: {all_chats}")
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
|
|
@ -810,7 +814,7 @@ class ChatTable:
|
|||
count = query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
print(f"Count of chats for tag '{tag_name}':", count)
|
||||
log.info(f"Count of chats for tag '{tag_name}': {count}")
|
||||
|
||||
return count
|
||||
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class FeedbackTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new feedback: {e}")
|
||||
return None
|
||||
|
||||
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ class FilesTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error inserting a new file: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class FolderTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error inserting a new folder: {e}")
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class FunctionsTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error creating a new function: {e}")
|
||||
return None
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
|
|
@ -170,7 +170,7 @@ class FunctionsTable:
|
|||
function = db.get(Function, id)
|
||||
return function.valves if function.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Error getting function valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_function_valves_by_id(
|
||||
|
|
@ -202,7 +202,9 @@ class FunctionsTable:
|
|||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
|
|
@ -225,7 +227,9 @@ class FunctionsTable:
|
|||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class ModelsTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to insert a new model: {e}")
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> list[ModelModel]:
|
||||
|
|
@ -246,8 +246,7 @@ class ModelsTable:
|
|||
db.refresh(model)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class TagTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error inserting a new tag: {e}")
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class ToolsTable:
|
|||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
log.exception(f"Error creating a new tool: {e}")
|
||||
return None
|
||||
|
||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
|
|
@ -175,7 +175,7 @@ class ToolsTable:
|
|||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Error getting tool valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
|
|
@ -204,7 +204,9 @@ class ToolsTable:
|
|||
|
||||
return user_settings["tools"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
|
|
@ -227,7 +229,9 @@ class ToolsTable:
|
|||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
|
|
|
|||
|
|
@ -271,6 +271,24 @@ class UsersTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
user_settings = db.query(User).filter_by(id=id).first().settings
|
||||
|
||||
if user_settings is None:
|
||||
user_settings = {}
|
||||
|
||||
user_settings.update(updated)
|
||||
|
||||
db.query(User).filter_by(id=id).update({"settings": user_settings})
|
||||
db.commit()
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_user_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
# Remove User from Groups
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import ftfy
|
|||
import sys
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
AzureAIDocumentIntelligenceLoader,
|
||||
BSHTMLLoader,
|
||||
CSVLoader,
|
||||
Docx2txtLoader,
|
||||
|
|
@ -76,6 +77,7 @@ known_source_ext = [
|
|||
"jsx",
|
||||
"hs",
|
||||
"lhs",
|
||||
"json",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -221,12 +223,33 @@ class Loader:
|
|||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif self.engine == "docling":
|
||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
||||
loader = DoclingLoader(
|
||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
|
||||
and (
|
||||
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
|
||||
or file_content_type
|
||||
in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
]
|
||||
)
|
||||
):
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
loader = PyPDFLoader(
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
import os
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from colbert.infra import ColBERTConfig
|
||||
from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ColBERT:
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
print("ColBERT: Loading model", name)
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
DOCKER = kwargs.get("env") == "docker"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
|||
|
||||
import asyncio
|
||||
import requests
|
||||
import hashlib
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
|
|
@ -14,9 +15,16 @@ from langchain_core.documents import Document
|
|||
|
||||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.files import Files
|
||||
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
OFFLINE_MODE,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
|
@ -61,9 +69,7 @@ class VectorSearchRetriever(BaseRetriever):
|
|||
|
||||
|
||||
def query_doc(
|
||||
collection_name: str,
|
||||
query_embedding: list[float],
|
||||
k: int,
|
||||
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
|
||||
):
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.search(
|
||||
|
|
@ -77,7 +83,20 @@ def query_doc(
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_doc(collection_name: str, user: UserModel = None):
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||
|
||||
if result:
|
||||
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting doc {collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
|
|
@ -134,47 +153,80 @@ def query_doc_with_hybrid_search(
|
|||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> list[dict]:
|
||||
def merge_get_results(get_results: list[dict]) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined_distances = []
|
||||
combined_documents = []
|
||||
combined_metadatas = []
|
||||
combined_ids = []
|
||||
|
||||
for data in query_results:
|
||||
combined_distances.extend(data["distances"][0])
|
||||
for data in get_results:
|
||||
combined_documents.extend(data["documents"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
combined_ids.extend(data["ids"][0])
|
||||
|
||||
# Create a list of tuples (distance, document, metadata)
|
||||
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"documents": [combined_documents],
|
||||
"metadatas": [combined_metadatas],
|
||||
"ids": [combined_ids],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def merge_and_sort_query_results(
|
||||
query_results: list[dict], k: int, reverse: bool = False
|
||||
) -> dict:
|
||||
# Initialize lists to store combined data
|
||||
combined = []
|
||||
seen_hashes = set() # To store unique document hashes
|
||||
|
||||
for data in query_results:
|
||||
distances = data["distances"][0]
|
||||
documents = data["documents"][0]
|
||||
metadatas = data["metadatas"][0]
|
||||
|
||||
for distance, document, metadata in zip(distances, documents, metadatas):
|
||||
if isinstance(document, str):
|
||||
doc_hash = hashlib.md5(
|
||||
document.encode()
|
||||
).hexdigest() # Compute a hash for uniqueness
|
||||
|
||||
if doc_hash not in seen_hashes:
|
||||
seen_hashes.add(doc_hash)
|
||||
combined.append((distance, document, metadata))
|
||||
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
|
||||
# We don't have anything :-(
|
||||
if not combined:
|
||||
sorted_distances = []
|
||||
sorted_documents = []
|
||||
sorted_metadatas = []
|
||||
else:
|
||||
# Unzip the sorted list
|
||||
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
||||
# Slice to keep only the top k elements
|
||||
sorted_distances, sorted_documents, sorted_metadatas = (
|
||||
zip(*combined[:k]) if combined else ([], [], [])
|
||||
)
|
||||
|
||||
# Slicing the lists to include only k elements
|
||||
sorted_distances = list(sorted_distances)[:k]
|
||||
sorted_documents = list(sorted_documents)[:k]
|
||||
sorted_metadatas = list(sorted_metadatas)[:k]
|
||||
|
||||
# Create the output dictionary
|
||||
result = {
|
||||
"distances": [sorted_distances],
|
||||
"documents": [sorted_documents],
|
||||
"metadatas": [sorted_metadatas],
|
||||
# Create and return the output dictionary
|
||||
return {
|
||||
"distances": [list(sorted_distances)],
|
||||
"documents": [list(sorted_documents)],
|
||||
"metadatas": [list(sorted_metadatas)],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_all_items_from_collections(collection_names: list[str]) -> dict:
|
||||
results = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = get_doc(collection_name=collection_name)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
return merge_get_results(results)
|
||||
|
||||
|
||||
def query_collection(
|
||||
|
|
@ -259,29 +311,35 @@ def get_embedding_function(
|
|||
embedding_batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
return lambda query, user=None: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
func = lambda query: generate_embeddings(
|
||||
func = lambda query, user=None: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
url=url,
|
||||
key=key,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def generate_multiple(query, func):
|
||||
def generate_multiple(query, user, func):
|
||||
if isinstance(query, list):
|
||||
embeddings = []
|
||||
for i in range(0, len(query), embedding_batch_size):
|
||||
embeddings.extend(func(query[i : i + embedding_batch_size]))
|
||||
embeddings.extend(
|
||||
func(query[i : i + embedding_batch_size], user=user)
|
||||
)
|
||||
return embeddings
|
||||
else:
|
||||
return func(query)
|
||||
return func(query, user)
|
||||
|
||||
return lambda query: generate_multiple(query, func)
|
||||
return lambda query, user=None: generate_multiple(query, user, func)
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||
|
||||
|
||||
def get_sources_from_files(
|
||||
request,
|
||||
files,
|
||||
queries,
|
||||
embedding_function,
|
||||
|
|
@ -289,21 +347,81 @@ def get_sources_from_files(
|
|||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
):
|
||||
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
||||
log.debug(
|
||||
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
|
||||
)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for file in files:
|
||||
if file.get("context") == "full":
|
||||
|
||||
context = None
|
||||
if file.get("docs"):
|
||||
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
context = {
|
||||
"documents": [[doc.get("content") for doc in file.get("docs")]],
|
||||
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
|
||||
}
|
||||
elif file.get("context") == "full":
|
||||
# Manual Full Mode Toggle
|
||||
context = {
|
||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
||||
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
||||
}
|
||||
else:
|
||||
context = None
|
||||
elif (
|
||||
file.get("type") != "web_search"
|
||||
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
# BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if file.get("type") == "collection":
|
||||
file_ids = file.get("data", {}).get("file_ids", [])
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
for file_id in file_ids:
|
||||
file_object = Files.get_file_by_id(file_id)
|
||||
|
||||
if file_object:
|
||||
documents.append(file_object.data.get("content", ""))
|
||||
metadatas.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
)
|
||||
|
||||
context = {
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
|
||||
elif file.get("id"):
|
||||
file_object = Files.get_file_by_id(file.get("id"))
|
||||
if file_object:
|
||||
context = {
|
||||
"documents": [[file_object.data.get("content", "")]],
|
||||
"metadatas": [
|
||||
[
|
||||
{
|
||||
"file_id": file.get("id"),
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
elif file.get("file").get("data"):
|
||||
context = {
|
||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
||||
"metadatas": [
|
||||
[file.get("file").get("data", {}).get("metadata", {})]
|
||||
],
|
||||
}
|
||||
else:
|
||||
collection_names = []
|
||||
if file.get("type") == "collection":
|
||||
if file.get("legacy"):
|
||||
|
|
@ -323,42 +441,50 @@ def get_sources_from_files(
|
|||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
if full_context:
|
||||
try:
|
||||
context = get_all_items_from_collections(collection_names)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
else:
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
if context:
|
||||
if "data" in file:
|
||||
del file["data"]
|
||||
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
sources = []
|
||||
|
|
@ -423,7 +549,11 @@ def get_model_path(model: str, update_model: bool = False):
|
|||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
|
||||
model: str,
|
||||
texts: list[str],
|
||||
url: str = "https://api.openai.com/v1",
|
||||
key: str = "",
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
|
|
@ -431,6 +561,16 @@ def generate_openai_batch_embeddings(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
|
|
@ -441,12 +581,12 @@ def generate_openai_batch_embeddings(
|
|||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating openai batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
model: str, texts: list[str], url: str, key: str = ""
|
||||
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
|
|
@ -454,6 +594,16 @@ def generate_ollama_batch_embeddings(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
|
|
@ -465,29 +615,36 @@ def generate_ollama_batch_embeddings(
|
|||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating ollama batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
url = kwargs.get("url", "")
|
||||
key = kwargs.get("key", "")
|
||||
user = kwargs.get("user")
|
||||
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{"model": model, "texts": text, "url": url, "key": key}
|
||||
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{"model": model, "texts": [text], "url": url, "key": key}
|
||||
**{
|
||||
"model": model,
|
||||
"texts": [text],
|
||||
"url": url,
|
||||
"key": key,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "openai":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, url, key)
|
||||
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
|
|||
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||
|
||||
VECTOR_DB_CLIENT = PgvectorClient()
|
||||
elif VECTOR_DB == "elasticsearch":
|
||||
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = ElasticsearchClient()
|
||||
else:
|
||||
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import chromadb
|
||||
import logging
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
|
|
@ -16,6 +17,10 @@ from open_webui.config import (
|
|||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
|
|
@ -102,8 +107,7 @@ class ChromaClient:
|
|||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
except:
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
from elasticsearch import Elasticsearch, BadRequestError
|
||||
from typing import Optional
|
||||
import ssl
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
ELASTICSEARCH_URL,
|
||||
ELASTICSEARCH_CA_CERTS,
|
||||
ELASTICSEARCH_API_KEY,
|
||||
ELASTICSEARCH_USERNAME,
|
||||
ELASTICSEARCH_PASSWORD,
|
||||
ELASTICSEARCH_CLOUD_ID,
|
||||
SSL_ASSERT_FINGERPRINT,
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchClient:
|
||||
"""
|
||||
Important:
|
||||
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
||||
an index for each file but store it as a text field, while seperating to different index
|
||||
baesd on the embedding length.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui_collections"
|
||||
self.client = Elasticsearch(
|
||||
hosts=[ELASTICSEARCH_URL],
|
||||
ca_certs=ELASTICSEARCH_CA_CERTS,
|
||||
api_key=ELASTICSEARCH_API_KEY,
|
||||
cloud_id=ELASTICSEARCH_CLOUD_ID,
|
||||
basic_auth=(
|
||||
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
|
||||
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
|
||||
else None
|
||||
),
|
||||
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
|
||||
)
|
||||
|
||||
# Status: works
|
||||
def _get_index_name(self, dimension: int) -> str:
|
||||
return f"{self.index_prefix}_d{str(dimension)}"
|
||||
|
||||
# Status: works
|
||||
def _scan_result_to_get_result(self, result) -> GetResult:
|
||||
if not result:
|
||||
return None
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
# Status: works
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
if not result["hits"]["hits"]:
|
||||
return None
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
# Status: works
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return SearchResult(
|
||||
ids=[ids],
|
||||
distances=[distances],
|
||||
documents=[documents],
|
||||
metadatas=[metadatas],
|
||||
)
|
||||
|
||||
# Status: works
|
||||
def _create_index(self, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"collection": {"type": "keyword"},
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
}
|
||||
}
|
||||
}
|
||||
self.client.indices.create(index=self._get_index_name(dimension), body=body)
|
||||
|
||||
# Status: works
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
for i in range(0, len(items), batch_size):
|
||||
yield items[i : min(i + batch_size, len(items))]
|
||||
|
||||
# Status: works
|
||||
def has_collection(self, collection_name) -> bool:
|
||||
query_body = {"query": {"bool": {"filter": []}}}
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
|
||||
try:
|
||||
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
|
||||
|
||||
return result.body["count"] > 0
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
# @TODO: Make this delete a collection and not an index
|
||||
def delete_colleciton(self, collection_name: str):
|
||||
# TODO: fix this to include the dimension or a * prefix
|
||||
# delete_collection here means delete a bunch of documents for an index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
self.client.indices.delete(index=self._get_collection_name(collection_name))
|
||||
|
||||
# Status: works
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {
|
||||
"bool": {"filter": [{"term": {"collection": collection_name}}]}
|
||||
},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
"params": {
|
||||
"vector": vectors[0]
|
||||
}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(len(vectors[0])), body=query
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
# Status: only tested halfwat
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}*",
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
# Status: works
|
||||
def _has_index(self, dimension: int):
|
||||
return self.client.indices.exists(
|
||||
index=self._get_index_name(dimension=dimension)
|
||||
)
|
||||
|
||||
def get_or_create_index(self, dimension: int):
|
||||
if not self._has_index(dimension=dimension):
|
||||
self._create_index(dimension=dimension)
|
||||
|
||||
# Status: works
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
query = {
|
||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
|
||||
"_source": ["text", "metadata"],
|
||||
}
|
||||
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
|
||||
|
||||
return self._scan_result_to_get_result(results)
|
||||
|
||||
# Status: works
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(dimension=len(items[0]["vector"]))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
bulk(self.client, actions)
|
||||
|
||||
# Status: should work
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(collection_name, dimension=len(items[0]["vector"]))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
self.client.bulk(actions)
|
||||
|
||||
# TODO: This currently deletes by * which is not always supported in ElasticSearch.
|
||||
# Need to read a bit before changing. Also, need to delete from a specific collection
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
# Assuming ID is unique across collections and indexes
|
||||
actions = [
|
||||
{"delete": {"_index": f"{self.index_prefix}*", "_id": id}} for id in ids
|
||||
]
|
||||
self.client.bulk(body=actions)
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}*")
|
||||
for index in indices:
|
||||
self.client.indices.delete(index=index)
|
||||
|
|
@ -1,20 +1,28 @@
|
|||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_DB,
|
||||
MILVUS_TOKEN,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
|
||||
if MILVUS_TOKEN is None:
|
||||
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
|
||||
else:
|
||||
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
ids = []
|
||||
|
|
@ -164,7 +172,7 @@ class MilvusClient:
|
|||
try:
|
||||
# Loop until there are no more items to fetch or the desired limit is reached
|
||||
while remaining > 0:
|
||||
print("remaining", remaining)
|
||||
log.info(f"remaining: {remaining}")
|
||||
current_fetch = min(
|
||||
max_limit, remaining
|
||||
) # Determine how many items to fetch in this iteration
|
||||
|
|
@ -191,10 +199,12 @@ class MilvusClient:
|
|||
if results_count < current_fetch:
|
||||
break
|
||||
|
||||
print(all_results)
|
||||
log.debug(all_results)
|
||||
return self._result_to_get_result([all_results])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(
|
||||
f"Error querying collection {collection_name} with limit {limit}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class OpenSearchClient:
|
|||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
|
||||
def _create_index(self, index_name: str, dimension: int):
|
||||
def _create_index(self, collection_name: str, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
|
|
@ -72,24 +72,28 @@ class OpenSearchClient:
|
|||
}
|
||||
}
|
||||
}
|
||||
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
|
||||
self.client.indices.create(
|
||||
index=f"{self.index_prefix}_{collection_name}", body=body
|
||||
)
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
for i in range(0, len(items), batch_size):
|
||||
yield items[i : i + batch_size]
|
||||
|
||||
def has_collection(self, index_name: str) -> bool:
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# has_collection here means has index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
|
||||
return self.client.indices.exists(
|
||||
index=f"{self.index_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def delete_colleciton(self, index_name: str):
|
||||
def delete_colleciton(self, collection_name: str):
|
||||
# delete_collection here means delete index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
|
||||
self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")
|
||||
|
||||
def search(
|
||||
self, index_name: str, vectors: list[list[float]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
|
|
@ -108,26 +112,55 @@ class OpenSearchClient:
|
|||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{index_name}", body=query
|
||||
index=f"{self.index_prefix}_{collection_name}", body=query
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def get_or_create_index(self, index_name: str, dimension: int):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension)
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
def get(self, index_name: str) -> Optional[GetResult]:
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{collection_name}",
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
|
||||
if not self.has_index(collection_name):
|
||||
self._create_index(collection_name, dimension)
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{index_name}", body=query
|
||||
index=f"{self.index_prefix}_{collection_name}", body=query
|
||||
)
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
def insert(self, index_name: str, items: list[VectorItem]):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
self._create_index_if_not_exists(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
|
|
@ -145,15 +178,17 @@ class OpenSearchClient:
|
|||
]
|
||||
self.client.bulk(actions)
|
||||
|
||||
def upsert(self, index_name: str, items: list[VectorItem]):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
self._create_index_if_not_exists(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"index": {
|
||||
"_id": item["id"],
|
||||
"_index": f"{self.index_prefix}_{collection_name}",
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
|
|
@ -165,9 +200,9 @@ class OpenSearchClient:
|
|||
]
|
||||
self.client.bulk(actions)
|
||||
|
||||
def delete(self, index_name: str, ids: list[str]):
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
actions = [
|
||||
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
|
||||
{"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
|
||||
for id in ids
|
||||
]
|
||||
self.client.bulk(body=actions)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
from sqlalchemy import (
|
||||
cast,
|
||||
column,
|
||||
|
|
@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError
|
|||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
|
|
@ -82,10 +88,10 @@ class PgvectorClient:
|
|||
)
|
||||
)
|
||||
self.session.commit()
|
||||
print("Initialization complete.")
|
||||
log.info("Initialization complete.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during initialization: {e}")
|
||||
log.exception(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
|
|
@ -150,12 +156,12 @@ class PgvectorClient:
|
|||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
print(
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during insert: {e}")
|
||||
log.exception(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
|
|
@ -184,10 +190,12 @@ class PgvectorClient:
|
|||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during upsert: {e}")
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
|
|
@ -278,7 +286,7 @@ class PgvectorClient:
|
|||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during search: {e}")
|
||||
log.exception(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
|
|
@ -310,7 +318,7 @@ class PgvectorClient:
|
|||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during query: {e}")
|
||||
log.exception(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
def get(
|
||||
|
|
@ -334,7 +342,7 @@ class PgvectorClient:
|
|||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
print(f"Error during get: {e}")
|
||||
log.exception(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
def delete(
|
||||
|
|
@ -356,22 +364,22 @@ class PgvectorClient:
|
|||
)
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
print(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during delete: {e}")
|
||||
log.exception(f"Error during delete: {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
print(
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during reset: {e}")
|
||||
log.exception(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
@ -387,9 +395,9 @@ class PgvectorClient:
|
|||
)
|
||||
return exists
|
||||
except Exception as e:
|
||||
print(f"Error checking collection existence: {e}")
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
self.delete(collection_name)
|
||||
print(f"Collection '{collection_name}' deleted.")
|
||||
log.info(f"Collection '{collection_name}' deleted.")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
|
|
@ -6,9 +7,13 @@ from qdrant_client.models import models
|
|||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
|
|
@ -49,7 +54,7 @@ class QdrantClient:
|
|||
),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
log.info(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(collection_name=collection_name):
|
||||
|
|
@ -120,7 +125,7 @@ class QdrantClient:
|
|||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error querying a collection '{collection_name}': {e}")
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import json
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def _parse_response(response):
|
||||
result = {}
|
||||
if "data" in response:
|
||||
data = response["data"]
|
||||
if "webPages" in data:
|
||||
webPages = data["webPages"]
|
||||
if "value" in webPages:
|
||||
result["webpage"] = [
|
||||
{
|
||||
"id": item.get("id", ""),
|
||||
"name": item.get("name", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("snippet", ""),
|
||||
"summary": item.get("summary", ""),
|
||||
"siteName": item.get("siteName", ""),
|
||||
"siteIcon": item.get("siteIcon", ""),
|
||||
"datePublished": item.get("datePublished", "")
|
||||
or item.get("dateLastCrawled", ""),
|
||||
}
|
||||
for item in webPages["value"]
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def search_bocha(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Bocha Search API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
payload = json.dumps(
|
||||
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
|
||||
)
|
||||
|
||||
response = requests.post(url, headers=headers, data=payload, timeout=5)
|
||||
response.raise_for_status()
|
||||
results = _parse_response(response.json())
|
||||
print(results)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("name"), snippet=result.get("summary")
|
||||
)
|
||||
for result in results.get("webpage", [])[:count]
|
||||
]
|
||||
|
|
@ -32,19 +32,15 @@ def search_duckduckgo(
|
|||
# Convert the search results into a list
|
||||
search_results = [r for r in ddgs_gen]
|
||||
|
||||
# Create an empty list to store the SearchResult objects
|
||||
results = []
|
||||
# Iterate over each search result
|
||||
for result in search_results:
|
||||
# Create a SearchResult object and append it to the results list
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["href"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("body"),
|
||||
)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
search_results = get_filtered_results(search_results, filter_list)
|
||||
|
||||
# Return the list of search results
|
||||
return results
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["href"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("body"),
|
||||
)
|
||||
for result in search_results
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.web.main import SearchResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
EXA_API_BASE = "https://api.exa.ai"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExaResult:
|
||||
url: str
|
||||
title: str
|
||||
text: str
|
||||
|
||||
|
||||
def search_exa(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Exa Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Exa Search API key
|
||||
query (str): The query to search for
|
||||
count (int): Number of results to return
|
||||
filter_list (Optional[list[str]]): List of domains to filter results by
|
||||
"""
|
||||
log.info(f"Searching with Exa for query: {query}")
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"numResults": count or 5,
|
||||
"includeDomains": filter_list,
|
||||
"contents": {"text": True, "highlights": True},
|
||||
"type": "auto", # Use the auto search type (keyword or neural)
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{EXA_API_BASE}/search", headers=headers, json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["results"]:
|
||||
results.append(
|
||||
ExaResult(
|
||||
url=result["url"],
|
||||
title=result["title"],
|
||||
text=result["text"],
|
||||
)
|
||||
)
|
||||
|
||||
log.info(f"Found {len(results)} results")
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.url,
|
||||
title=result.title,
|
||||
snippet=result.text,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
except Exception as e:
|
||||
log.error(f"Error searching Exa: {e}")
|
||||
return []
|
||||
|
|
@ -17,34 +17,53 @@ def search_google_pse(
|
|||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
|
||||
Handles pagination for counts greater than 10.
|
||||
|
||||
Args:
|
||||
api_key (str): A Programmable Search Engine API key
|
||||
search_engine_id (str): A Programmable Search Engine ID
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
|
||||
filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResult objects.
|
||||
"""
|
||||
url = "https://www.googleapis.com/customsearch/v1"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
params = {
|
||||
"cx": search_engine_id,
|
||||
"q": query,
|
||||
"key": api_key,
|
||||
"num": count,
|
||||
}
|
||||
all_results = []
|
||||
start_index = 1 # Google PSE start parameter is 1-based
|
||||
|
||||
response = requests.request("GET", url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
while count > 0:
|
||||
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
|
||||
params = {
|
||||
"cx": search_engine_id,
|
||||
"q": query,
|
||||
"key": api_key,
|
||||
"num": num_results_this_page,
|
||||
"start": start_index,
|
||||
}
|
||||
response = requests.request("GET", url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
all_results.extend(results)
|
||||
count -= len(
|
||||
results
|
||||
) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
else:
|
||||
break # No more results from Google PSE, break the loop
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
all_results = get_filtered_results(all_results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in results
|
||||
for result in all_results
|
||||
]
|
||||
|
|
|
|||
|
|
@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
|||
list[SearchResult]: A list of search results
|
||||
"""
|
||||
jina_search_endpoint = "https://s.jina.ai/"
|
||||
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
url = str(URL(jina_search_endpoint + query))
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": api_key,
|
||||
"X-Retain-Images": "none",
|
||||
}
|
||||
|
||||
payload = {"q": query, "count": count if count <= 10 else 10}
|
||||
|
||||
url = str(URL(jina_search_endpoint))
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["data"][:count]:
|
||||
for result in data["data"]:
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import validators
|
||||
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
|
|||
filtered_results = []
|
||||
for result in results:
|
||||
url = result.get("url") or result.get("link", "")
|
||||
if not validators.url(url):
|
||||
continue
|
||||
domain = urlparse(url).netloc
|
||||
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
|
||||
filtered_results.append(result)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
import logging
|
||||
from typing import Optional, List
|
||||
import requests
|
||||
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_perplexity(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Perplexity API key
|
||||
query (str): The query to search for
|
||||
count (int): Maximum number of results to return
|
||||
|
||||
"""
|
||||
|
||||
# Handle PersistentConfig object
|
||||
if hasattr(api_key, "__str__"):
|
||||
api_key = str(api_key)
|
||||
|
||||
try:
|
||||
url = "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
# Create payload for the API call
|
||||
payload = {
|
||||
"model": "sonar",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide factual information with citations.",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
"temperature": 0.2, # Lower temperature for more factual responses
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Make the API request
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
# Parse the JSON response
|
||||
json_response = response.json()
|
||||
|
||||
# Extract citations from the response
|
||||
citations = json_response.get("citations", [])
|
||||
|
||||
# Create search results from citations
|
||||
results = []
|
||||
for i, citation in enumerate(citations[:count]):
|
||||
# Extract content from the response to use as snippet
|
||||
content = ""
|
||||
if "choices" in json_response and json_response["choices"]:
|
||||
if i == 0:
|
||||
content = json_response["choices"][0]["message"]["content"]
|
||||
|
||||
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
|
||||
results.append(result)
|
||||
|
||||
if filter_list:
|
||||
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error searching with Perplexity API: {e}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serpapi(
|
||||
api_key: str,
|
||||
engine: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serpapi.com's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serpapi.com API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://serpapi.com/search"
|
||||
|
||||
engine = engine or "google"
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serpapi search: {json_response}")
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult
|
||||
|
|
@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
def search_tavily(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
# **kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
|
|
@ -20,7 +27,6 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
|
|||
"""
|
||||
url = "https://api.tavily.com/search"
|
||||
data = {"query": query, "api_key": api_key}
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,38 @@
|
|||
import socket
|
||||
import urllib.parse
|
||||
import validators
|
||||
from typing import Union, Sequence, Iterator
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
Literal,
|
||||
)
|
||||
import aiohttp
|
||||
import certifi
|
||||
import validators
|
||||
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
||||
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
PLAYWRIGHT_WS_URI,
|
||||
RAG_WEB_LOADER_ENGINE,
|
||||
FIRECRAWL_API_BASE_URL,
|
||||
FIRECRAWL_API_KEY,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
|
@ -43,6 +62,17 @@ def validate_url(url: Union[str, Sequence[str]]):
|
|||
return False
|
||||
|
||||
|
||||
def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
|
||||
valid_urls = []
|
||||
for u in url:
|
||||
try:
|
||||
if validate_url(u):
|
||||
valid_urls.append(u)
|
||||
except ValueError:
|
||||
continue
|
||||
return valid_urls
|
||||
|
||||
|
||||
def resolve_hostname(hostname):
|
||||
# Get address information
|
||||
addr_info = socket.getaddrinfo(hostname, None)
|
||||
|
|
@ -54,9 +84,381 @@ def resolve_hostname(hostname):
|
|||
return ipv4_addresses, ipv6_addresses
|
||||
|
||||
|
||||
def extract_metadata(soup, url):
|
||||
metadata = {"source": url}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get("content", "No description found.")
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
return metadata
|
||||
|
||||
|
||||
def verify_ssl_cert(url: str) -> bool:
|
||||
"""Verify SSL certificate for the given URL."""
|
||||
if not url.startswith("https://"):
|
||||
return True
|
||||
|
||||
try:
|
||||
hostname = url.split("://")[-1].split("/")[0]
|
||||
context = ssl.create_default_context(cafile=certifi.where())
|
||||
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
|
||||
s.connect((hostname, 443))
|
||||
return True
|
||||
except ssl.SSLError:
|
||||
return False
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class SafeFireCrawlLoader(BaseLoader):
|
||||
def __init__(
|
||||
self,
|
||||
web_paths,
|
||||
verify_ssl: bool = True,
|
||||
trust_env: bool = False,
|
||||
requests_per_second: Optional[float] = None,
|
||||
continue_on_failure: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
mode: Literal["crawl", "scrape", "map"] = "crawl",
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
params: Optional[Dict] = None,
|
||||
):
|
||||
"""Concurrent document loader for FireCrawl operations.
|
||||
|
||||
Executes multiple FireCrawlLoader instances concurrently using thread pooling
|
||||
to improve bulk processing efficiency.
|
||||
Args:
|
||||
web_paths: List of URLs/paths to process.
|
||||
verify_ssl: If True, verify SSL certificates.
|
||||
trust_env: If True, use proxy settings from environment variables.
|
||||
requests_per_second: Number of requests per second to limit to.
|
||||
continue_on_failure (bool): If True, continue loading other URLs on failure.
|
||||
api_key: API key for FireCrawl service. Defaults to None
|
||||
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
|
||||
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
|
||||
mode: Operation mode selection:
|
||||
- 'crawl': Website crawling mode (default)
|
||||
- 'scrape': Direct page scraping
|
||||
- 'map': Site map generation
|
||||
proxy: Proxy override settings for the FireCrawl API.
|
||||
params: The parameters to pass to the Firecrawl API.
|
||||
Examples include crawlerOptions.
|
||||
For more details, visit: https://github.com/mendableai/firecrawl-py
|
||||
"""
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
self.web_paths = web_paths
|
||||
self.verify_ssl = verify_ssl
|
||||
self.requests_per_second = requests_per_second
|
||||
self.last_request_time = None
|
||||
self.trust_env = trust_env
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.mode = mode
|
||||
self.params = params
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Load documents concurrently using FireCrawl."""
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
self._safe_process_url_sync(url)
|
||||
loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
mode=self.mode,
|
||||
params=self.params,
|
||||
)
|
||||
yield from loader.lazy_load()
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
|
||||
async def alazy_load(self):
|
||||
"""Async version of lazy_load."""
|
||||
for url in self.web_paths:
|
||||
try:
|
||||
await self._safe_process_url(url)
|
||||
loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
mode=self.mode,
|
||||
params=self.params,
|
||||
)
|
||||
async for document in loader.alazy_load():
|
||||
yield document
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
|
||||
class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
|
||||
|
||||
Attributes:
|
||||
web_paths (List[str]): List of URLs to load.
|
||||
verify_ssl (bool): If True, verify SSL certificates.
|
||||
trust_env (bool): If True, use proxy settings from environment variables.
|
||||
requests_per_second (Optional[float]): Number of requests per second to limit to.
|
||||
continue_on_failure (bool): If True, continue loading other URLs on failure.
|
||||
headless (bool): If True, the browser will run in headless mode.
|
||||
proxy (dict): Proxy override settings for the Playwright session.
|
||||
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_paths: List[str],
|
||||
verify_ssl: bool = True,
|
||||
trust_env: bool = False,
|
||||
requests_per_second: Optional[float] = None,
|
||||
continue_on_failure: bool = True,
|
||||
headless: bool = True,
|
||||
remove_selectors: Optional[List[str]] = None,
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
playwright_ws_url: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with additional safety parameters and remote browser support."""
|
||||
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
|
||||
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
|
||||
super().__init__(
|
||||
urls=web_paths,
|
||||
continue_on_failure=continue_on_failure,
|
||||
headless=headless if playwright_ws_url is None else False,
|
||||
remove_selectors=remove_selectors,
|
||||
proxy=proxy,
|
||||
)
|
||||
self.verify_ssl = verify_ssl
|
||||
self.requests_per_second = requests_per_second
|
||||
self.last_request_time = None
|
||||
self.playwright_ws_url = playwright_ws_url
|
||||
self.trust_env = trust_env
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Safely load URLs synchronously with support for remote browser."""
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
with sync_playwright() as p:
|
||||
# Use remote browser if ws_endpoint is provided, otherwise use local browser
|
||||
if self.playwright_ws_url:
|
||||
browser = p.chromium.connect(self.playwright_ws_url)
|
||||
else:
|
||||
browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
|
||||
|
||||
for url in self.urls:
|
||||
try:
|
||||
self._safe_process_url_sync(url)
|
||||
page = browser.new_page()
|
||||
response = page.goto(url)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
|
||||
text = self.evaluator.evaluate(page, browser, response)
|
||||
metadata = {"source": url}
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
browser.close()
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Safely load URLs asynchronously with support for remote browser."""
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
async with async_playwright() as p:
|
||||
# Use remote browser if ws_endpoint is provided, otherwise use local browser
|
||||
if self.playwright_ws_url:
|
||||
browser = await p.chromium.connect(self.playwright_ws_url)
|
||||
else:
|
||||
browser = await p.chromium.launch(
|
||||
headless=self.headless, proxy=self.proxy
|
||||
)
|
||||
|
||||
for url in self.urls:
|
||||
try:
|
||||
await self._safe_process_url(url)
|
||||
page = await browser.new_page()
|
||||
response = await page.goto(url)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
|
||||
text = await self.evaluator.evaluate_async(page, browser, response)
|
||||
metadata = {"source": url}
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(e, "Error loading %s", url)
|
||||
continue
|
||||
raise e
|
||||
await browser.close()
|
||||
|
||||
def _verify_ssl_cert(self, url: str) -> bool:
|
||||
return verify_ssl_cert(url)
|
||||
|
||||
async def _wait_for_rate_limit(self):
|
||||
"""Wait to respect the rate limit if specified."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
await asyncio.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
def _sync_wait_for_rate_limit(self):
|
||||
"""Synchronous version of rate limit wait."""
|
||||
if self.requests_per_second and self.last_request_time:
|
||||
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
|
||||
time_since_last = datetime.now() - self.last_request_time
|
||||
if time_since_last < min_interval:
|
||||
time.sleep((min_interval - time_since_last).total_seconds())
|
||||
self.last_request_time = datetime.now()
|
||||
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
|
||||
class SafeWebBaseLoader(WebBaseLoader):
|
||||
"""WebBaseLoader with enhanced error handling for URLs."""
|
||||
|
||||
def __init__(self, trust_env: bool = False, *args, **kwargs):
|
||||
"""Initialize SafeWebBaseLoader
|
||||
Args:
|
||||
trust_env (bool, optional): set to True if using proxy to make web requests, for example
|
||||
using http(s)_proxy environment variables. Defaults to False.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.trust_env = trust_env
|
||||
|
||||
async def _fetch(
|
||||
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
|
||||
) -> str:
|
||||
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
|
||||
for i in range(retries):
|
||||
try:
|
||||
kwargs: Dict = dict(
|
||||
headers=self.session.headers,
|
||||
cookies=self.session.cookies.get_dict(),
|
||||
)
|
||||
if not self.session.verify:
|
||||
kwargs["ssl"] = False
|
||||
|
||||
async with session.get(
|
||||
url, **(self.requests_kwargs | kwargs)
|
||||
) as response:
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
return await response.text()
|
||||
except aiohttp.ClientConnectionError as e:
|
||||
if i == retries - 1:
|
||||
raise
|
||||
else:
|
||||
log.warning(
|
||||
f"Error fetching {url} with attempt "
|
||||
f"{i + 1}/{retries}: {e}. Retrying..."
|
||||
)
|
||||
await asyncio.sleep(cooldown * backoff**i)
|
||||
raise ValueError("retry count exceeded")
|
||||
|
||||
def _unpack_fetch_results(
|
||||
self, results: Any, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
"""Unpack fetch results into BeautifulSoup objects."""
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
final_results = []
|
||||
for i, result in enumerate(results):
|
||||
url = urls[i]
|
||||
if parser is None:
|
||||
if url.endswith(".xml"):
|
||||
parser = "xml"
|
||||
else:
|
||||
parser = self.default_parser
|
||||
self._check_parser(parser)
|
||||
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
|
||||
return final_results
|
||||
|
||||
async def ascrape_all(
|
||||
self, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
"""Async fetch all urls, then return soups for all results."""
|
||||
results = await self.fetch_all(urls)
|
||||
return self._unpack_fetch_results(results, urls, parser=parser)
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Lazy load text from the url(s) in web_path with error handling."""
|
||||
for path in self.web_paths:
|
||||
|
|
@ -65,33 +467,72 @@ class SafeWebBaseLoader(WebBaseLoader):
|
|||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
|
||||
# Build metadata
|
||||
metadata = {"source": path}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get(
|
||||
"content", "No description found."
|
||||
)
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
metadata = extract_metadata(soup, path)
|
||||
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
# Log the error and continue with the next URL
|
||||
log.error(f"Error loading {path}: {e}")
|
||||
log.exception(e, "Error loading %s", path)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async lazy load text from the url(s) in web_path."""
|
||||
results = await self.ascrape_all(self.web_paths)
|
||||
for path, soup in zip(self.web_paths, results):
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
metadata = {"source": path}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get(
|
||||
"content", "No description found."
|
||||
)
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
async def aload(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
|
||||
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
|
||||
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
|
||||
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
|
||||
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
|
||||
|
||||
|
||||
def get_web_loader(
|
||||
urls: Union[str, Sequence[str]],
|
||||
verify_ssl: bool = True,
|
||||
requests_per_second: int = 2,
|
||||
trust_env: bool = False,
|
||||
):
|
||||
# Check if the URL is valid
|
||||
if not validate_url(urls):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
return SafeWebBaseLoader(
|
||||
urls,
|
||||
verify_ssl=verify_ssl,
|
||||
requests_per_second=requests_per_second,
|
||||
continue_on_failure=True,
|
||||
# Check if the URLs are valid
|
||||
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
||||
|
||||
web_loader_args = {
|
||||
"web_paths": safe_urls,
|
||||
"verify_ssl": verify_ssl,
|
||||
"requests_per_second": requests_per_second,
|
||||
"continue_on_failure": True,
|
||||
"trust_env": trust_env,
|
||||
}
|
||||
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
|
||||
|
||||
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
|
||||
|
||||
# Create the appropriate WebLoader based on the configuration
|
||||
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
|
||||
log.debug(
|
||||
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
|
||||
web_loader.__class__.__name__,
|
||||
len(safe_urls),
|
||||
)
|
||||
|
||||
return web_loader
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
|
|||
import aiohttp
|
||||
import aiofiles
|
||||
import requests
|
||||
import mimetypes
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
|
|
@ -36,6 +37,7 @@ from open_webui.config import (
|
|||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
DEVICE_TYPE,
|
||||
|
|
@ -52,7 +54,7 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
|
|
@ -69,7 +71,7 @@ from pydub.utils import mediainfo
|
|||
def is_mp4_audio(file_path):
|
||||
"""Check if the given file is an MP4 audio file."""
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
log.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
info = mediainfo(file_path)
|
||||
|
|
@ -86,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
|
|||
"""Convert MP4 audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format="mp4")
|
||||
audio.export(output_path, format="wav")
|
||||
print(f"Converted {file_path} to {output_path}")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
|
|
@ -138,6 +140,7 @@ class STTConfigForm(BaseModel):
|
|||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
DEEPGRAM_API_KEY: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
|
|
@ -165,6 +168,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
|||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -190,6 +194,7 @@ async def update_audio_config(
|
|||
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
request.app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
|
|
@ -214,6 +219,7 @@ async def update_audio_config(
|
|||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -260,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
json=payload,
|
||||
|
|
@ -318,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
|
||||
json={
|
||||
|
|
@ -375,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
||||
<voice name="{language}">{payload["input"]}</voice>
|
||||
</speak>"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, trust_env=True
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
|
||||
headers={
|
||||
|
|
@ -453,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
print("transcribe", file_path)
|
||||
log.info(f"transcribe: {file_path}")
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
|
@ -521,6 +535,69 @@ def transcribe(request: Request, file_path):
|
|||
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
elif request.app.state.config.STT_ENGINE == "deepgram":
|
||||
try:
|
||||
# Determine the MIME type of the file
|
||||
mime, _ = mimetypes.guess_type(file_path)
|
||||
if not mime:
|
||||
mime = "audio/wav" # fallback to wav if undetectable
|
||||
|
||||
# Read the audio file
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Build headers and parameters
|
||||
headers = {
|
||||
"Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
|
||||
"Content-Type": mime,
|
||||
}
|
||||
|
||||
# Add model if specified
|
||||
params = {}
|
||||
if request.app.state.config.STT_MODEL:
|
||||
params["model"] = request.app.state.config.STT_MODEL
|
||||
|
||||
# Make request to Deepgram API
|
||||
r = requests.post(
|
||||
"https://api.deepgram.com/v1/listen",
|
||||
headers=headers,
|
||||
params=params,
|
||||
data=file_data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
response_data = r.json()
|
||||
|
||||
# Extract transcript from Deepgram response
|
||||
try:
|
||||
transcript = response_data["results"]["channels"][0]["alternatives"][
|
||||
0
|
||||
].get("transcript", "")
|
||||
except (KeyError, IndexError) as e:
|
||||
log.error(f"Malformed response from Deepgram: {str(e)}")
|
||||
raise Exception(
|
||||
"Failed to parse Deepgram response - unexpected response format"
|
||||
)
|
||||
data = {"text": transcript.strip()}
|
||||
|
||||
# Save transcript
|
||||
transcript_file = f"{file_dir}/{id}.json"
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
|
||||
def compress_audio(file_path):
|
||||
if os.path.getsize(file_path) > MAX_FILE_SIZE:
|
||||
|
|
@ -602,7 +679,22 @@ def transcription(
|
|||
def get_available_models(request: Request) -> list[dict]:
|
||||
available_models = []
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
available_models = data.get("models", [])
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from custom endpoint: {str(e)}")
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
else:
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
response = requests.get(
|
||||
|
|
@ -633,14 +725,37 @@ def get_available_voices(request) -> dict:
|
|||
"""Returns {voice_id: voice_name} dict"""
|
||||
available_voices = {}
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
voices_list = data.get("voices", [])
|
||||
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
else:
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
available_voices = get_elevenlabs_voices(
|
||||
|
|
|
|||
|
|
@ -25,16 +25,13 @@ from open_webui.env import (
|
|||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from open_webui.config import (
|
||||
OPENID_PROVIDER_URL,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
)
|
||||
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||
from pydantic import BaseModel
|
||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||
from open_webui.utils.auth import (
|
||||
|
|
@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
|
|||
from typing import Optional, List
|
||||
|
||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
if ENABLE_LDAP.value:
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
from ldap3.utils.conv import escape_filter_chars
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -95,8 +94,8 @@ async def get_session_user(
|
|||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
|
|
@ -164,7 +163,7 @@ async def update_password(
|
|||
############################
|
||||
# LDAP Authentication
|
||||
############################
|
||||
@router.post("/ldap", response_model=SigninResponse)
|
||||
@router.post("/ldap", response_model=SessionUserResponse)
|
||||
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
|
||||
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
|
||||
|
|
@ -231,9 +230,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
|
||||
entry = connection_app.entries[0]
|
||||
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not mail or mail == "" or mail == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have mail.")
|
||||
email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
|
||||
if not email or email == "" or email == "[]":
|
||||
raise HTTPException(400, f"User {form_data.user} does not have email.")
|
||||
else:
|
||||
email = email.lower()
|
||||
|
||||
cn = str(entry["cn"])
|
||||
user_dn = entry.entry_dn
|
||||
|
||||
|
|
@ -248,17 +250,22 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
if not connection_user.bind():
|
||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||
|
||||
user = Users.get_user_by_email(mail)
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
if user_count == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=mail, password=str(uuid.uuid4()), name=cn, role=role
|
||||
email=email,
|
||||
password=str(uuid.uuid4()),
|
||||
name=cn,
|
||||
role=role,
|
||||
)
|
||||
|
||||
if not user:
|
||||
|
|
@ -271,7 +278,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
|
||||
user = Auths.authenticate_user_by_trusted_header(mail)
|
||||
user = Auths.authenticate_user_by_trusted_header(email)
|
||||
|
||||
if user:
|
||||
token = create_token(
|
||||
|
|
@ -288,6 +295,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
|
|
@ -296,6 +307,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
"permissions": user_permissions,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
|
@ -378,8 +390,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
|
|
@ -408,6 +420,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
not request.app.state.config.ENABLE_SIGNUP
|
||||
|
|
@ -422,6 +435,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
|
|
@ -432,12 +446,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
|
||||
try:
|
||||
role = (
|
||||
"admin"
|
||||
if Users.get_num_users() == 0
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
if Users.get_num_users() == 0:
|
||||
if user_count == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
|
|
@ -473,12 +485,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
|
|
@ -525,7 +538,8 @@ async def signout(request: Request, response: Response):
|
|||
if logout_url:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
return RedirectResponse(
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}"
|
||||
headers=response.headers,
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -591,7 +605,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
|||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
||||
print(admin_email, admin_name)
|
||||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ async def get_channel_messages(
|
|||
############################
|
||||
|
||||
|
||||
async def send_notification(webui_url, channel, message, active_user_ids):
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
|
|
@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
|
|||
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
name,
|
||||
webhook_url,
|
||||
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
|
||||
{
|
||||
|
|
@ -302,6 +303,7 @@ async def post_new_message(
|
|||
|
||||
background_tasks.add_task(
|
||||
send_notification,
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
|
|
|
|||
|
|
@ -444,15 +444,21 @@ async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
############################
|
||||
|
||||
|
||||
class CloneForm(BaseModel):
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
async def clone_chat_by_id(
|
||||
form_data: CloneForm, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
updated_chat = {
|
||||
**chat.chat,
|
||||
"originalChatId": chat.id,
|
||||
"branchPointMessageId": chat.chat["history"]["currentId"],
|
||||
"title": f"Clone of {chat.title}",
|
||||
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
|
|
|
|||
|
|
@ -36,6 +36,140 @@ async def export_config(user=Depends(get_admin_user)):
|
|||
return get_config()
|
||||
|
||||
|
||||
############################
|
||||
# Direct Connections Config
|
||||
############################
|
||||
|
||||
|
||||
class DirectConnectionsConfigForm(BaseModel):
|
||||
ENABLE_DIRECT_CONNECTIONS: bool
|
||||
|
||||
|
||||
@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
|
||||
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
|
||||
async def set_direct_connections_config(
|
||||
request: Request,
|
||||
form_data: DirectConnectionsConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
|
||||
form_data.ENABLE_DIRECT_CONNECTIONS
|
||||
)
|
||||
return {
|
||||
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
class CodeInterpreterConfigForm(BaseModel):
|
||||
CODE_EXECUTION_ENGINE: str
|
||||
CODE_EXECUTION_JUPYTER_URL: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
|
||||
ENABLE_CODE_INTERPRETER: bool
|
||||
CODE_INTERPRETER_ENGINE: str
|
||||
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_URL: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
|
||||
|
||||
|
||||
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_execution_config(
|
||||
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_URL
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
|
||||
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
|
||||
form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_URL
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
|
||||
)
|
||||
|
||||
return {
|
||||
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
|
||||
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# SetDefaultModels
|
||||
############################
|
||||
|
|
|
|||
|
|
@ -3,30 +3,23 @@ import os
|
|||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import mimetypes
|
||||
from urllib.parse import quote
|
||||
|
||||
from open_webui.storage.provider import Storage
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.models.files import (
|
||||
FileForm,
|
||||
FileModel,
|
||||
FileModelResponse,
|
||||
Files,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_file, ProcessFileForm
|
||||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||
from open_webui.routers.audio import transcribe
|
||||
from open_webui.storage.provider import Storage
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from pydantic import BaseModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
|
@ -41,7 +34,10 @@ router = APIRouter()
|
|||
|
||||
@router.post("/", response_model=FileModelResponse)
|
||||
def upload_file(
|
||||
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
file_metadata: dict = {},
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
|
|
@ -65,13 +61,29 @@ def upload_file(
|
|||
"name": name,
|
||||
"content_type": file.content_type,
|
||||
"size": len(contents),
|
||||
"data": file_metadata,
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
process_file(request, ProcessFileForm(file_id=id))
|
||||
if file.content_type in [
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/x-m4a",
|
||||
]:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -126,7 +138,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
|||
Storage.delete_all_files()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting files")
|
||||
log.error("Error deleting files")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||
|
|
@ -193,7 +205,9 @@ async def update_file_data_content_by_id(
|
|||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=id, content=form_data.content)
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=form_data.content),
|
||||
user=user,
|
||||
)
|
||||
file = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
|
|
@ -227,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename) # RFC5987 encoding
|
||||
|
||||
content_type = file.meta.get("content_type")
|
||||
filename = file.meta.get("name", file.filename)
|
||||
encoded_filename = quote(filename)
|
||||
headers = {}
|
||||
if file.meta.get("content_type") not in [
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
]:
|
||||
headers = {
|
||||
**headers,
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
|
||||
return FileResponse(file_path, headers=headers)
|
||||
if content_type == "application/pdf" or filename.lower().endswith(
|
||||
".pdf"
|
||||
):
|
||||
headers["Content-Disposition"] = (
|
||||
f"inline; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
content_type = "application/pdf"
|
||||
elif content_type != "text/plain":
|
||||
headers["Content-Disposition"] = (
|
||||
f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
)
|
||||
|
||||
return FileResponse(file_path, headers=headers, media_type=content_type)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -246,7 +267,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error getting file content")
|
||||
log.error("Error getting file content")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||
|
|
@ -268,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
log.info(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -277,7 +298,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error getting file content")
|
||||
log.error("Error getting file content")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||
|
|
@ -353,7 +374,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||
Storage.delete_file(file.path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting files")
|
||||
log.error("Error deleting files")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
|
|||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -68,7 +74,7 @@ async def create_new_function(
|
|||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
|
||||
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
|
||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if function:
|
||||
|
|
@ -79,7 +85,7 @@ async def create_new_function(
|
|||
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to create a new function: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
@ -183,7 +189,7 @@ async def update_function_by_id(
|
|||
FUNCTIONS[id] = function_module
|
||||
|
||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
|
||||
|
|
@ -299,7 +305,7 @@ async def update_function_valves_by_id(
|
|||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
|
|||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating function user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import (
|
||||
|
|
@ -14,7 +14,13 @@ from open_webui.models.groups import (
|
|||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
|
|
@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
|
|||
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error creating a new group: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
@ -94,7 +100,7 @@ async def update_group_by_id(
|
|||
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error updating group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
|||
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error deleting group {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
|
|||
|
|
@ -1,37 +1,31 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||
from open_webui.routers.files import upload_file
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.images.comfyui import (
|
||||
ComfyUIGenerateImageForm,
|
||||
ComfyUIWorkflow,
|
||||
comfyui_generate_image,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||
|
||||
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
||||
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
|
|
@ -61,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
|||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -84,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
|
|||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
|
||||
class GeminiConfigForm(BaseModel):
|
||||
GEMINI_API_BASE_URL: str
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
|
||||
class ConfigForm(BaseModel):
|
||||
enabled: bool
|
||||
engine: str
|
||||
|
|
@ -91,6 +94,7 @@ class ConfigForm(BaseModel):
|
|||
openai: OpenAIConfigForm
|
||||
automatic1111: Automatic1111ConfigForm
|
||||
comfyui: ComfyUIConfigForm
|
||||
gemini: GeminiConfigForm
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
|
|
@ -109,6 +113,11 @@ async def update_config(
|
|||
)
|
||||
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
||||
|
||||
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
||||
form_data.gemini.GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
|
||||
|
||||
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
||||
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
||||
)
|
||||
|
|
@ -135,6 +144,8 @@ async def update_config(
|
|||
request.app.state.config.COMFYUI_BASE_URL = (
|
||||
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
||||
)
|
||||
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
|
||||
|
||||
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
||||
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
|
||||
form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
||||
|
|
@ -161,6 +172,10 @@ async def update_config(
|
|||
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
},
|
||||
"gemini": {
|
||||
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -190,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
|
||||
headers = None
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
|
||||
headers=headers,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
|
|
@ -230,6 +253,12 @@ def get_image_model(request):
|
|||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "dall-e-2"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "imagen-3.0-generate-002"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
|
|
@ -271,7 +300,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
|||
async def update_image_config(
|
||||
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
set_image_model(request, form_data.MODEL)
|
||||
|
||||
pattern = r"^\d+x\d+$"
|
||||
|
|
@ -306,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
|||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
headers = {
|
||||
|
|
@ -329,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
|||
if model_node_id:
|
||||
model_list_key = None
|
||||
|
||||
print(workflow[model_node_id]["class_type"])
|
||||
log.info(workflow[model_node_id]["class_type"])
|
||||
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
||||
"required"
|
||||
]:
|
||||
|
|
@ -383,40 +415,22 @@ class GenerateImageForm(BaseModel):
|
|||
negative_prompt: Optional[str] = None
|
||||
|
||||
|
||||
def save_b64_image(b64_str):
|
||||
def load_b64_image_data(b64_str):
|
||||
try:
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
if "," in b64_str:
|
||||
header, encoded = b64_str.split(",", 1)
|
||||
mime_type = header.split(";")[0]
|
||||
|
||||
img_data = base64.b64decode(encoded)
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
else:
|
||||
image_filename = f"{image_id}.png"
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
||||
|
||||
mime_type = "image/png"
|
||||
img_data = base64.b64decode(b64_str)
|
||||
|
||||
# Write the image data to a file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
|
||||
return img_data, mime_type
|
||||
except Exception as e:
|
||||
log.exception(f"Error saving image: {e}")
|
||||
log.exception(f"Error loading image data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_url_image(url, headers=None):
|
||||
image_id = str(uuid.uuid4())
|
||||
def load_url_image_data(url, headers=None):
|
||||
try:
|
||||
if headers:
|
||||
r = requests.get(url, headers=headers)
|
||||
|
|
@ -426,18 +440,7 @@ def save_url_image(url, headers=None):
|
|||
r.raise_for_status()
|
||||
if r.headers["content-type"].split("/")[0] == "image":
|
||||
mime_type = r.headers["content-type"]
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
if not image_format:
|
||||
raise ValueError("Could not determine image type from MIME type")
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
|
||||
with open(file_path, "wb") as image_file:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
image_file.write(chunk)
|
||||
return image_filename
|
||||
return r.content, mime_type
|
||||
else:
|
||||
log.error("Url does not point to an image.")
|
||||
return None
|
||||
|
|
@ -447,6 +450,20 @@ def save_url_image(url, headers=None):
|
|||
return None
|
||||
|
||||
|
||||
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
|
||||
headers={
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/generations")
|
||||
async def image_generations(
|
||||
request: Request,
|
||||
|
|
@ -500,12 +517,49 @@ async def image_generations(
|
|||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_filename = save_b64_image(image["b64_json"])
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
if "url" in image:
|
||||
image_data, content_type = load_url_image_data(
|
||||
image["url"], headers
|
||||
)
|
||||
else:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(data, f)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
|
||||
|
||||
model = get_image_model(request)
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = load_b64_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
|
|
@ -552,14 +606,15 @@ async def image_generations(
|
|||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||
}
|
||||
|
||||
image_filename = save_url_image(image["url"], headers)
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(form_data.model_dump(exclude_none=True), f)
|
||||
|
||||
log.debug(f"images: {images}")
|
||||
image_data, content_type = load_url_image_data(image["url"], headers)
|
||||
url = upload_image(
|
||||
request,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
image_data,
|
||||
content_type,
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
elif (
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
||||
|
|
@ -604,13 +659,15 @@ async def image_generations(
|
|||
images = []
|
||||
|
||||
for image in res["images"]:
|
||||
image_filename = save_b64_image(image)
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump({**data, "info": res["info"]}, f)
|
||||
|
||||
image_data, content_type = load_b64_image_data(image)
|
||||
url = upload_image(
|
||||
request,
|
||||
{**data, "info": res["info"]},
|
||||
image_data,
|
||||
content_type,
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
except Exception as e:
|
||||
error = e
|
||||
|
|
|
|||
|
|
@ -264,7 +264,11 @@ def add_file_to_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
@ -285,7 +289,9 @@ def add_file_to_knowledge_by_id(
|
|||
# Add content to the vector database
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
|
||||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
|
@ -342,7 +348,12 @@ def update_file_from_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
@ -363,7 +374,9 @@ def update_file_from_knowledge_by_id(
|
|||
# Add content to the vector database
|
||||
try:
|
||||
process_file(
|
||||
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
|
||||
request,
|
||||
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -406,7 +419,11 @@ def remove_file_from_knowledge_by_id(
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
@ -429,10 +446,6 @@ def remove_file_from_knowledge_by_id(
|
|||
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
|
||||
|
||||
# Delete physical file
|
||||
if file.path:
|
||||
Storage.delete_file(file.path)
|
||||
|
||||
# Delete file from database
|
||||
Files.delete_file_by_id(form_data.file_id)
|
||||
|
||||
|
|
@ -484,7 +497,11 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
@ -543,7 +560,11 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
@ -582,14 +603,18 @@ def add_files_to_knowledge_batch(
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if knowledge.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# Get files content
|
||||
print(f"files/batch/add - {len(form_data)} files")
|
||||
log.info(f"files/batch/add - {len(form_data)} files")
|
||||
files: List[FileModel] = []
|
||||
for form in form_data:
|
||||
file = Files.get_file_by_id(form.file_id)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def add_memory(
|
|||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"metadata": {"created_at": memory.created_at},
|
||||
}
|
||||
],
|
||||
|
|
@ -82,7 +82,7 @@ async def query_memory(
|
|||
):
|
||||
results = VECTOR_DB_CLIENT.search(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
|
||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
|
||||
limit=form_data.k,
|
||||
)
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
|
|||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
|
|
@ -160,7 +160,9 @@ async def update_memory_by_id(
|
|||
{
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
|
||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user
|
||||
),
|
||||
"metadata": {
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
|
|
|
|||
|
|
@ -183,7 +183,11 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if model.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
user.role != "admin"
|
||||
and model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
|
|
|
|||
|
|
@ -11,11 +11,14 @@ import re
|
|||
import time
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
|
||||
import requests
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.env import (
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
|
|
@ -28,7 +31,7 @@ from fastapi import (
|
|||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
|
||||
|
|
@ -52,7 +55,7 @@ from open_webui.env import (
|
|||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
|
@ -68,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
|||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
|
|
@ -98,6 +115,7 @@ async def send_post_request(
|
|||
stream: bool = True,
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
):
|
||||
|
||||
r = None
|
||||
|
|
@ -112,6 +130,16 @@ async def send_post_request(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
|
@ -188,12 +216,24 @@ async def verify_connection(
|
|||
key = form_data.key
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/api/version",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
detail = f"HTTP Error: {r.status}"
|
||||
|
|
@ -256,7 +296,7 @@ async def update_config(
|
|||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request):
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
|
|
@ -264,7 +304,7 @@ async def get_all_models(request: Request):
|
|||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags"))
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
|
|
@ -277,7 +317,9 @@ async def get_all_models(request: Request):
|
|||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(send_get_request(f"{url}/api/tags", key))
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/tags", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
|
|
@ -362,7 +404,7 @@ async def get_ollama_tags(
|
|||
models = []
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
|
@ -372,7 +414,19 @@ async def get_ollama_tags(
|
|||
r = requests.request(
|
||||
method="GET",
|
||||
url=f"{url}/api/tags",
|
||||
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
|
|
@ -395,7 +449,7 @@ async def get_ollama_tags(
|
|||
)
|
||||
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models["models"] = get_filtered_models(models, user)
|
||||
models["models"] = await get_filtered_models(models, user)
|
||||
|
||||
return models
|
||||
|
||||
|
|
@ -479,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
|
|||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
user=user,
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
|
|
@ -511,6 +566,7 @@ async def pull_model(
|
|||
url=f"{url}/api/pull",
|
||||
payload=json.dumps(payload),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -529,7 +585,7 @@ async def push_model(
|
|||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
|
|
@ -547,6 +603,7 @@ async def push_model(
|
|||
url=f"{url}/api/push",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -573,6 +630,7 @@ async def create_model(
|
|||
url=f"{url}/api/create",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -590,7 +648,7 @@ async def copy_model(
|
|||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.source in models:
|
||||
|
|
@ -611,6 +669,16 @@ async def copy_model(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
|
|
@ -645,7 +713,7 @@ async def delete_model(
|
|||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name in models:
|
||||
|
|
@ -667,6 +735,16 @@ async def delete_model(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
|
@ -695,7 +773,7 @@ async def delete_model(
|
|||
async def show_model_info(
|
||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
if form_data.name not in models:
|
||||
|
|
@ -716,6 +794,16 @@ async def show_model_info(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
|
|
@ -759,7 +847,7 @@ async def embed(
|
|||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
|
|
@ -785,6 +873,16 @@ async def embed(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
|
|
@ -828,7 +926,7 @@ async def embeddings(
|
|||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
|
|
@ -854,6 +952,16 @@ async def embeddings(
|
|||
headers={
|
||||
"Content-Type": "application/json",
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
|
|
@ -903,7 +1011,7 @@ async def generate_completion(
|
|||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
model = form_data.model
|
||||
|
|
@ -933,23 +1041,39 @@ async def generate_completion(
|
|||
url=f"{url}/api/generate",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[dict]] = None
|
||||
images: Optional[list[str]] = None
|
||||
|
||||
@validator("content", pre=True)
|
||||
@classmethod
|
||||
def check_at_least_one_field(cls, field_value, values, **kwargs):
|
||||
# Raise an error if both 'content' and 'tool_calls' are None
|
||||
if field_value is None and (
|
||||
"tool_calls" not in values or values["tool_calls"] is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of 'content' or 'tool_calls' must be provided"
|
||||
)
|
||||
|
||||
return field_value
|
||||
|
||||
|
||||
class GenerateChatCompletionForm(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
format: Optional[dict] = None
|
||||
format: Optional[Union[dict, str]] = None
|
||||
options: Optional[dict] = None
|
||||
template: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
|
||||
|
||||
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
|
||||
|
|
@ -977,6 +1101,7 @@ async def generate_chat_completion(
|
|||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
metadata = form_data.pop("metadata", None)
|
||||
try:
|
||||
form_data = GenerateChatCompletionForm(**form_data)
|
||||
except Exception as e:
|
||||
|
|
@ -1006,7 +1131,7 @@ async def generate_chat_completion(
|
|||
payload["options"] = apply_model_params_to_body_ollama(
|
||||
params, payload["options"]
|
||||
)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
|
|
@ -1046,6 +1171,7 @@ async def generate_chat_completion(
|
|||
stream=form_data.stream,
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1148,6 +1274,7 @@ async def generate_openai_completion(
|
|||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1159,6 +1286,8 @@ async def generate_openai_chat_completion(
|
|||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
metadata = form_data.pop("metadata", None)
|
||||
|
||||
try:
|
||||
completion_form = OpenAIChatCompletionForm(**form_data)
|
||||
except Exception as e:
|
||||
|
|
@ -1185,7 +1314,7 @@ async def generate_openai_chat_completion(
|
|||
|
||||
if params:
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
|
|
@ -1224,6 +1353,7 @@ async def generate_openai_chat_completion(
|
|||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1237,7 +1367,7 @@ async def get_openai_models(
|
|||
|
||||
models = []
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models(request)
|
||||
model_list = await get_all_models(request, user=user)
|
||||
models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
|
|
@ -1405,9 +1535,10 @@ async def download_model(
|
|||
return None
|
||||
|
||||
|
||||
# TODO: Progress bar does not reflect size & duration of upload.
|
||||
@router.post("/models/upload")
|
||||
@router.post("/models/upload/{url_idx}")
|
||||
def upload_model(
|
||||
async def upload_model(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
url_idx: Optional[int] = None,
|
||||
|
|
@ -1416,59 +1547,85 @@ def upload_model(
|
|||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
# --- P1: save file locally ---
|
||||
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
|
||||
with open(file_path, "wb") as out_f:
|
||||
while True:
|
||||
chunk = file.file.read(chunk_size)
|
||||
# log.info(f"Chunk: {str(chunk)}") # DEBUG
|
||||
if not chunk:
|
||||
break
|
||||
out_f.write(chunk)
|
||||
|
||||
# Save file in chunks
|
||||
with open(file_path, "wb+") as f:
|
||||
for chunk in file.file:
|
||||
f.write(chunk)
|
||||
|
||||
def file_process_stream():
|
||||
async def file_process_stream():
|
||||
nonlocal ollama_url
|
||||
total_size = os.path.getsize(file_path)
|
||||
chunk_size = 1024 * 1024
|
||||
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
|
||||
|
||||
# --- P2: SSE progress + calculate sha256 hash ---
|
||||
file_hash = calculate_sha256(file_path, chunk_size)
|
||||
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
total = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
done = True
|
||||
continue
|
||||
|
||||
total += len(chunk)
|
||||
progress = round((total / total_size) * 100, 2)
|
||||
|
||||
res = {
|
||||
bytes_read = 0
|
||||
while chunk := f.read(chunk_size):
|
||||
bytes_read += len(chunk)
|
||||
progress = round(bytes_read / total_size * 100, 2)
|
||||
data_msg = {
|
||||
"progress": progress,
|
||||
"total": total_size,
|
||||
"completed": total,
|
||||
"completed": bytes_read,
|
||||
}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
yield f"data: {json.dumps(data_msg)}\n\n"
|
||||
|
||||
if done:
|
||||
f.seek(0)
|
||||
hashed = calculate_sha256(f)
|
||||
f.seek(0)
|
||||
# --- P3: Upload to ollama /api/blobs ---
|
||||
with open(file_path, "rb") as f:
|
||||
url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
|
||||
response = requests.post(url, data=f)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=f)
|
||||
if response.ok:
|
||||
log.info(f"Uploaded to /api/blobs") # DEBUG
|
||||
# Remove local file
|
||||
os.remove(file_path)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file.filename,
|
||||
}
|
||||
os.remove(file_path)
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
"Ollama: Could not create blob, Please try again."
|
||||
)
|
||||
# Create model in ollama
|
||||
model_name, ext = os.path.splitext(file.filename)
|
||||
log.info(f"Created Model: {model_name}") # DEBUG
|
||||
|
||||
create_payload = {
|
||||
"model": model_name,
|
||||
# Reference the file by its original name => the uploaded blob's digest
|
||||
"files": {file.filename: f"sha256:{file_hash}"},
|
||||
}
|
||||
log.info(f"Model Payload: {create_payload}") # DEBUG
|
||||
|
||||
# Call ollama /api/create
|
||||
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
|
||||
create_resp = requests.post(
|
||||
url=f"{ollama_url}/api/create",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(create_payload),
|
||||
)
|
||||
|
||||
if create_resp.ok:
|
||||
log.info(f"API SUCCESS!") # DEBUG
|
||||
done_msg = {
|
||||
"done": True,
|
||||
"blob": f"sha256:{file_hash}",
|
||||
"name": file.filename,
|
||||
"model_created": model_name,
|
||||
}
|
||||
yield f"data: {json.dumps(done_msg)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to create model in Ollama. {create_resp.text}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Ollama: Could not create blob, Please try again.")
|
||||
|
||||
except Exception as e:
|
||||
res = {"error": str(e)}
|
||||
|
|
|
|||
|
|
@ -22,10 +22,11 @@ from open_webui.config import (
|
|||
)
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
|
|
@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
|||
##########################################
|
||||
|
||||
|
||||
async def send_get_request(url, key=None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
async def send_get_request(url, key=None, user: UserModel = None):
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
|
||||
url,
|
||||
headers={
|
||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
|
|
@ -75,18 +89,24 @@ async def cleanup_response(
|
|||
await session.close()
|
||||
|
||||
|
||||
def openai_o1_handler(payload):
|
||||
def openai_o1_o3_handler(payload):
|
||||
"""
|
||||
Handle O1 specific parameters
|
||||
Handle o1, o3 specific parameters
|
||||
"""
|
||||
if "max_tokens" in payload:
|
||||
# Remove "max_tokens" from the payload
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
|
||||
# Fix: o1 and o3 do not support the "system" role directly.
|
||||
# For older models like "o1-mini" or "o1-preview", use role "user".
|
||||
# For newer o1/o3 models, replace "system" with "developer".
|
||||
if payload["messages"][0]["role"] == "system":
|
||||
payload["messages"][0]["role"] = "user"
|
||||
model_lower = payload["model"].lower()
|
||||
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
|
||||
payload["messages"][0]["role"] = "user"
|
||||
else:
|
||||
payload["messages"][0]["role"] = "developer"
|
||||
|
||||
return payload
|
||||
|
||||
|
|
@ -172,7 +192,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
body = await request.body()
|
||||
name = hashlib.sha256(body).hexdigest()
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
|
|
@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||
|
||||
|
||||
async def get_all_models_responses(request: Request) -> list:
|
||||
async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
|
||||
|
|
@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
|
|||
):
|
||||
request_tasks.append(
|
||||
send_get_request(
|
||||
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
|
|||
send_get_request(
|
||||
f"{url}/models",
|
||||
request.app.state.config.OPENAI_API_KEYS[idx],
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -352,13 +375,13 @@ async def get_filtered_models(models, user):
|
|||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models(request: Request) -> dict[str, list]:
|
||||
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||
log.info("get_all_models()")
|
||||
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user=user)
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
|
|
@ -418,16 +441,14 @@ async def get_models(
|
|||
}
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models(request)
|
||||
models = await get_all_models(request, user=user)
|
||||
else:
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
|
||||
r = None
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
||||
)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
|
|
@ -489,7 +510,7 @@ async def get_models(
|
|||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models["data"] = get_filtered_models(models, user)
|
||||
models["data"] = await get_filtered_models(models, user)
|
||||
|
||||
return models
|
||||
|
||||
|
|
@ -507,7 +528,7 @@ async def verify_connection(
|
|||
key = form_data.key
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
|
|
@ -515,6 +536,16 @@ async def verify_connection(
|
|||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
|
|
@ -551,9 +582,9 @@ async def generate_chat_completion(
|
|||
bypass_filter = True
|
||||
|
||||
idx = 0
|
||||
|
||||
payload = {**form_data}
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
metadata = payload.pop("metadata", None)
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
|
@ -566,7 +597,7 @@ async def generate_chat_completion(
|
|||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
|
|
@ -587,7 +618,7 @@ async def generate_chat_completion(
|
|||
detail="Model not found",
|
||||
)
|
||||
|
||||
await get_all_models(request)
|
||||
await get_all_models(request, user=user)
|
||||
model = request.app.state.OPENAI_MODELS.get(model_id)
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
|
|
@ -621,10 +652,10 @@ async def generate_chat_completion(
|
|||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1 = payload["model"].lower().startswith("o1-")
|
||||
if is_o1:
|
||||
payload = openai_o1_handler(payload)
|
||||
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
|
||||
if is_o1_o3:
|
||||
payload = openai_o1_o3_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
|
|
@ -777,7 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
log.error(res)
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from fastapi import (
|
|||
status,
|
||||
APIRouter,
|
||||
)
|
||||
import aiohttp
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
|
|
@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
|
|||
return sorted_filters
|
||||
|
||||
|
||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
res = r.json()
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
res = (
|
||||
await response.json()
|
||||
if response.content_type == "application/json"
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
raise Exception(r.status_code, res["detail"])
|
||||
raise Exception(response.status, res["detail"])
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
r = requests.post(
|
||||
if not key:
|
||||
continue
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
request_data = {
|
||||
"user": user,
|
||||
"body": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers={"Authorization": f"Bearer {key}"},
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
payload = data
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
try:
|
||||
res = r.json()
|
||||
res = (
|
||||
await response.json()
|
||||
if "application/json" in response.content_type
|
||||
else {}
|
||||
)
|
||||
if "detail" in res:
|
||||
return Exception(r.status_code, res)
|
||||
raise Exception(response.status, res)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
return payload
|
||||
|
||||
|
|
@ -161,7 +169,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/list")
|
||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||
responses = await get_all_models_responses(request)
|
||||
responses = await get_all_models_responses(request, user)
|
||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||
|
||||
urlIdxs = [
|
||||
|
|
@ -188,7 +196,7 @@ async def upload_pipeline(
|
|||
file: UploadFile = File(...),
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
print("upload_pipeline", urlIdx, file.filename)
|
||||
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
|
||||
# Check if the uploaded file is a python file
|
||||
if not (file.filename and file.filename.endswith(".py")):
|
||||
raise HTTPException(
|
||||
|
|
@ -223,7 +231,7 @@ async def upload_pipeline(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
|
|
@ -274,7 +282,7 @@ async def add_pipeline(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
|
|
@ -319,7 +327,7 @@ async def delete_pipeline(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
|
|
@ -353,7 +361,7 @@ async def get_pipelines(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
|
|
@ -392,7 +400,7 @@ async def get_pipeline_valves(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
|
|
@ -432,7 +440,7 @@ async def get_pipeline_valves_spec(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
if r is not None:
|
||||
|
|
@ -474,7 +482,7 @@ async def update_pipeline_valves(
|
|||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.exception(f"Connection error: {e}")
|
||||
|
||||
detail = None
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,11 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if prompt.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from fastapi import (
|
|||
APIRouter,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
import tiktoken
|
||||
|
||||
|
|
@ -45,17 +46,20 @@ from open_webui.retrieval.web.utils import get_web_loader
|
|||
from open_webui.retrieval.web.brave import search_brave
|
||||
from open_webui.retrieval.web.kagi import search_kagi
|
||||
from open_webui.retrieval.web.mojeek import search_mojeek
|
||||
from open_webui.retrieval.web.bocha import search_bocha
|
||||
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
|
||||
from open_webui.retrieval.web.google_pse import search_google_pse
|
||||
from open_webui.retrieval.web.jina_search import search_jina
|
||||
from open_webui.retrieval.web.searchapi import search_searchapi
|
||||
from open_webui.retrieval.web.serpapi import search_serpapi
|
||||
from open_webui.retrieval.web.searxng import search_searxng
|
||||
from open_webui.retrieval.web.serper import search_serper
|
||||
from open_webui.retrieval.web.serply import search_serply
|
||||
from open_webui.retrieval.web.serpstack import search_serpstack
|
||||
from open_webui.retrieval.web.tavily import search_tavily
|
||||
from open_webui.retrieval.web.bing import search_bing
|
||||
|
||||
from open_webui.retrieval.web.exa import search_exa
|
||||
from open_webui.retrieval.web.perplexity import search_perplexity
|
||||
|
||||
from open_webui.retrieval.utils import (
|
||||
get_embedding_function,
|
||||
|
|
@ -347,11 +351,18 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
|
|
@ -368,10 +379,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
|
|
@ -379,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
|
||||
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": request.app.state.config.SERPER_API_KEY,
|
||||
|
|
@ -386,11 +400,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
|
||||
"searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
|
||||
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
|
||||
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
|
||||
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
|
||||
"jina_api_key": request.app.state.config.JINA_API_KEY,
|
||||
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -401,10 +421,16 @@ class FileConfig(BaseModel):
|
|||
max_count: Optional[int] = None
|
||||
|
||||
|
||||
class DocumentIntelligenceConfigForm(BaseModel):
|
||||
endpoint: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContentExtractionConfig(BaseModel):
|
||||
engine: str = ""
|
||||
tika_server_url: Optional[str] = None
|
||||
docling_server_url: Optional[str] = None
|
||||
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
|
||||
|
||||
|
||||
class ChunkParamUpdateForm(BaseModel):
|
||||
|
|
@ -428,6 +454,7 @@ class WebSearchConfig(BaseModel):
|
|||
brave_search_api_key: Optional[str] = None
|
||||
kagi_search_api_key: Optional[str] = None
|
||||
mojeek_search_api_key: Optional[str] = None
|
||||
bocha_search_api_key: Optional[str] = None
|
||||
serpstack_api_key: Optional[str] = None
|
||||
serpstack_https: Optional[bool] = None
|
||||
serper_api_key: Optional[str] = None
|
||||
|
|
@ -435,21 +462,31 @@ class WebSearchConfig(BaseModel):
|
|||
tavily_api_key: Optional[str] = None
|
||||
searchapi_api_key: Optional[str] = None
|
||||
searchapi_engine: Optional[str] = None
|
||||
serpapi_api_key: Optional[str] = None
|
||||
serpapi_engine: Optional[str] = None
|
||||
jina_api_key: Optional[str] = None
|
||||
bing_search_v7_endpoint: Optional[str] = None
|
||||
bing_search_v7_subscription_key: Optional[str] = None
|
||||
exa_api_key: Optional[str] = None
|
||||
perplexity_api_key: Optional[str] = None
|
||||
result_count: Optional[int] = None
|
||||
concurrent_requests: Optional[int] = None
|
||||
trust_env: Optional[bool] = None
|
||||
domain_filter_list: Optional[List[str]] = []
|
||||
|
||||
|
||||
class WebConfig(BaseModel):
|
||||
search: WebSearchConfig
|
||||
web_loader_ssl_verification: Optional[bool] = None
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
|
||||
|
||||
class ConfigUpdateForm(BaseModel):
|
||||
RAG_FULL_CONTEXT: Optional[bool] = None
|
||||
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
pdf_extract_images: Optional[bool] = None
|
||||
enable_google_drive_integration: Optional[bool] = None
|
||||
enable_onedrive_integration: Optional[bool] = None
|
||||
file: Optional[FileConfig] = None
|
||||
content_extraction: Optional[ContentExtractionConfig] = None
|
||||
chunk: Optional[ChunkParamUpdateForm] = None
|
||||
|
|
@ -467,18 +504,38 @@ async def update_rag_config(
|
|||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_FULL_CONTEXT = (
|
||||
form_data.RAG_FULL_CONTEXT
|
||||
if form_data.RAG_FULL_CONTEXT is not None
|
||||
else request.app.state.config.RAG_FULL_CONTEXT
|
||||
)
|
||||
|
||||
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
|
||||
else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
|
||||
form_data.enable_google_drive_integration
|
||||
if form_data.enable_google_drive_integration is not None
|
||||
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
|
||||
form_data.enable_onedrive_integration
|
||||
if form_data.enable_onedrive_integration is not None
|
||||
else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
|
||||
)
|
||||
|
||||
if form_data.file is not None:
|
||||
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
|
||||
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
||||
|
||||
if form_data.content_extraction is not None:
|
||||
log.info(f"Updating text settings: {form_data.content_extraction}")
|
||||
log.info(
|
||||
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
|
||||
)
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
|
||||
form_data.content_extraction.engine
|
||||
)
|
||||
|
|
@ -488,6 +545,13 @@ async def update_rag_config(
|
|||
request.app.state.config.DOCLING_SERVER_URL = (
|
||||
form_data.content_extraction.docling_server_url
|
||||
)
|
||||
if form_data.content_extraction.document_intelligence_config is not None:
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
|
||||
form_data.content_extraction.document_intelligence_config.endpoint
|
||||
)
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
|
||||
form_data.content_extraction.document_intelligence_config.key
|
||||
)
|
||||
|
||||
if form_data.chunk is not None:
|
||||
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
||||
|
|
@ -502,11 +566,16 @@ async def update_rag_config(
|
|||
if form_data.web is not None:
|
||||
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
|
||||
form_data.web.web_loader_ssl_verification
|
||||
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
|
||||
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
|
||||
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
|
||||
request.app.state.config.SEARXNG_QUERY_URL = (
|
||||
form_data.web.search.searxng_query_url
|
||||
)
|
||||
|
|
@ -525,6 +594,9 @@ async def update_rag_config(
|
|||
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
|
||||
form_data.web.search.mojeek_search_api_key
|
||||
)
|
||||
request.app.state.config.BOCHA_SEARCH_API_KEY = (
|
||||
form_data.web.search.bocha_search_api_key
|
||||
)
|
||||
request.app.state.config.SERPSTACK_API_KEY = (
|
||||
form_data.web.search.serpstack_api_key
|
||||
)
|
||||
|
|
@ -539,6 +611,9 @@ async def update_rag_config(
|
|||
form_data.web.search.searchapi_engine
|
||||
)
|
||||
|
||||
request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
|
||||
request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
|
||||
|
||||
request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
|
||||
request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
|
||||
form_data.web.search.bing_search_v7_endpoint
|
||||
|
|
@ -547,16 +622,30 @@ async def update_rag_config(
|
|||
form_data.web.search.bing_search_v7_subscription_key
|
||||
)
|
||||
|
||||
request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
|
||||
|
||||
request.app.state.config.PERPLEXITY_API_KEY = (
|
||||
form_data.web.search.perplexity_api_key
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
|
||||
form_data.web.search.result_count
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||
form_data.web.search.concurrent_requests
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
|
||||
form_data.web.search.trust_env
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
|
||||
form_data.web.search.domain_filter_list
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
|
||||
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
|
||||
"file": {
|
||||
"max_size": request.app.state.config.FILE_MAX_SIZE,
|
||||
"max_count": request.app.state.config.FILE_MAX_COUNT,
|
||||
|
|
@ -565,6 +654,10 @@ async def update_rag_config(
|
|||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
|
|
@ -577,7 +670,8 @@ async def update_rag_config(
|
|||
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
|
||||
},
|
||||
"web": {
|
||||
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"search": {
|
||||
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
|
||||
|
|
@ -587,18 +681,25 @@ async def update_rag_config(
|
|||
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
|
||||
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": request.app.state.config.SERPER_API_KEY,
|
||||
"serply_api_key": request.app.state.config.SERPLY_API_KEY,
|
||||
"serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
|
||||
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
|
||||
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
|
||||
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
|
||||
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
|
||||
"jina_api_key": request.app.state.config.JINA_API_KEY,
|
||||
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -666,6 +767,7 @@ def save_docs_to_vector_db(
|
|||
overwrite: bool = False,
|
||||
split: bool = True,
|
||||
add: bool = False,
|
||||
user=None,
|
||||
) -> bool:
|
||||
def _get_docs_info(docs: list[Document]) -> str:
|
||||
docs_info = set()
|
||||
|
|
@ -746,7 +848,11 @@ def save_docs_to_vector_db(
|
|||
# for meta-data so convert them to string.
|
||||
for metadata in metadatas:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, datetime):
|
||||
if (
|
||||
isinstance(value, datetime)
|
||||
or isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
):
|
||||
metadata[key] = str(value)
|
||||
|
||||
try:
|
||||
|
|
@ -781,7 +887,7 @@ def save_docs_to_vector_db(
|
|||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts))
|
||||
list(map(lambda x: x.replace("\n", " "), texts)), user=user
|
||||
)
|
||||
|
||||
items = [
|
||||
|
|
@ -829,7 +935,12 @@ def process_file(
|
|||
# Update the content in the file
|
||||
# Usage: /files/{file_id}/data/content/update
|
||||
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
try:
|
||||
# /files/{file_id}/data/content/update
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
||||
except:
|
||||
# Audio file upload pipeline
|
||||
pass
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
|
|
@ -887,6 +998,8 @@ def process_file(
|
|||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
)
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
|
|
@ -929,35 +1042,45 @@ def process_file(
|
|||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
)
|
||||
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
request,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.filename,
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filename": file.filename,
|
||||
"content": text_content,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
|
|
@ -997,7 +1120,7 @@ def process_text(
|
|||
text_content = form_data.content
|
||||
log.debug(f"text_content: {text_content}")
|
||||
|
||||
result = save_docs_to_vector_db(request, docs, collection_name)
|
||||
result = save_docs_to_vector_db(request, docs, collection_name, user=user)
|
||||
if result:
|
||||
return {
|
||||
"status": True,
|
||||
|
|
@ -1030,7 +1153,9 @@ def process_youtube_video(
|
|||
content = " ".join([doc.page_content for doc in docs])
|
||||
log.debug(f"text_content: {content}")
|
||||
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
save_docs_to_vector_db(
|
||||
request, docs, collection_name, overwrite=True, user=user
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
|
|
@ -1071,7 +1196,13 @@ def process_web(
|
|||
content = " ".join([doc.page_content for doc in docs])
|
||||
|
||||
log.debug(f"text_content: {content}")
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
|
||||
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
save_docs_to_vector_db(
|
||||
request, docs, collection_name, overwrite=True, user=user
|
||||
)
|
||||
else:
|
||||
collection_name = None
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
|
|
@ -1083,6 +1214,7 @@ def process_web(
|
|||
},
|
||||
"meta": {
|
||||
"name": form_data.url,
|
||||
"source": form_data.url,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -1102,11 +1234,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
- BRAVE_SEARCH_API_KEY
|
||||
- KAGI_SEARCH_API_KEY
|
||||
- MOJEEK_SEARCH_API_KEY
|
||||
- BOCHA_SEARCH_API_KEY
|
||||
- SERPSTACK_API_KEY
|
||||
- SERPER_API_KEY
|
||||
- SERPLY_API_KEY
|
||||
- TAVILY_API_KEY
|
||||
- EXA_API_KEY
|
||||
- PERPLEXITY_API_KEY
|
||||
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
|
||||
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
|
||||
Args:
|
||||
query (str): The query to search for
|
||||
"""
|
||||
|
|
@ -1168,6 +1304,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
)
|
||||
else:
|
||||
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
|
||||
elif engine == "bocha":
|
||||
if request.app.state.config.BOCHA_SEARCH_API_KEY:
|
||||
return search_bocha(
|
||||
request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
|
||||
elif engine == "serpstack":
|
||||
if request.app.state.config.SERPSTACK_API_KEY:
|
||||
return search_serpstack(
|
||||
|
|
@ -1211,6 +1357,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
request.app.state.config.TAVILY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No TAVILY_API_KEY found in environment variables")
|
||||
|
|
@ -1225,6 +1372,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
)
|
||||
else:
|
||||
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
|
||||
elif engine == "serpapi":
|
||||
if request.app.state.config.SERPAPI_API_KEY:
|
||||
return search_serpapi(
|
||||
request.app.state.config.SERPAPI_API_KEY,
|
||||
request.app.state.config.SERPAPI_ENGINE,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SERPAPI_API_KEY found in environment variables")
|
||||
elif engine == "jina":
|
||||
return search_jina(
|
||||
request.app.state.config.JINA_API_KEY,
|
||||
|
|
@ -1240,12 +1398,26 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
|||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "exa":
|
||||
return search_exa(
|
||||
request.app.state.config.EXA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "perplexity":
|
||||
return search_perplexity(
|
||||
request.app.state.config.PERPLEXITY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No search engine API key found in environment variables")
|
||||
|
||||
|
||||
@router.post("/process/web/search")
|
||||
def process_web_search(
|
||||
async def process_web_search(
|
||||
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
|
|
@ -1277,15 +1449,40 @@ def process_web_search(
|
|||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = loader.load()
|
||||
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
|
||||
docs = await loader.aload()
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
}
|
||||
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": None,
|
||||
"filenames": urls,
|
||||
"docs": [
|
||||
{
|
||||
"content": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
for doc in docs
|
||||
],
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
else:
|
||||
await run_in_threadpool(
|
||||
save_docs_to_vector_db,
|
||||
request,
|
||||
docs,
|
||||
collection_name,
|
||||
overwrite=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filenames": urls,
|
||||
"loaded_count": len(docs),
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
|
@ -1313,7 +1510,9 @@ def query_doc_handler(
|
|||
return query_doc_with_hybrid_search(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=(
|
||||
|
|
@ -1321,12 +1520,16 @@ def query_doc_handler(
|
|||
if form_data.r
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(
|
||||
form_data.query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -1355,7 +1558,9 @@ def query_collection_handler(
|
|||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=(
|
||||
|
|
@ -1368,7 +1573,9 @@ def query_collection_handler(
|
|||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
||||
query, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
)
|
||||
|
||||
|
|
@ -1432,11 +1639,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
|||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"The directory {folder} does not exist")
|
||||
log.warning(f"The directory {folder} does not exist")
|
||||
except Exception as e:
|
||||
print(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -1516,6 +1723,7 @@ def process_files_batch(
|
|||
docs=all_docs,
|
||||
collection_name=collection_name,
|
||||
add=True,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Update all files with collection name
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import logging
|
||||
import re
|
||||
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.task import (
|
||||
|
|
@ -19,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
|||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
|
|
@ -57,6 +62,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
|||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
|
|
@ -67,6 +73,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
|||
class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
ENABLE_TITLE_GENERATION: bool
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
|
|
@ -85,10 +92,15 @@ async def update_task_config(
|
|||
):
|
||||
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
|
||||
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
|
||||
form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
)
|
||||
|
|
@ -117,6 +129,7 @@ async def update_task_config(
|
|||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
|
|
@ -134,7 +147,19 @@ async def update_task_config(
|
|||
async def generate_title(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
models = request.app.state.MODELS
|
||||
|
||||
if not request.app.state.config.ENABLE_TITLE_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Title generation is disabled"},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -161,9 +186,20 @@ async def generate_title(
|
|||
else:
|
||||
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
messages = form_data["messages"]
|
||||
|
||||
# Remove reasoning details from the messages
|
||||
for message in messages:
|
||||
message["content"] = re.sub(
|
||||
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
|
||||
"",
|
||||
message["content"],
|
||||
flags=re.S,
|
||||
).strip()
|
||||
|
||||
content = title_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
messages,
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
|
|
@ -175,19 +211,26 @@ async def generate_title(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 50}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
{"max_tokens": 1000}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 50,
|
||||
"max_completion_tokens": 1000,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TITLE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -209,7 +252,12 @@ async def generate_chat_tags(
|
|||
content={"detail": "Tags generation is disabled"},
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -245,12 +293,19 @@ async def generate_chat_tags(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TAGS_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -265,7 +320,12 @@ async def generate_chat_tags(
|
|||
async def generate_image_prompt(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -305,12 +365,19 @@ async def generate_image_prompt(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -340,7 +407,12 @@ async def generate_queries(
|
|||
detail=f"Query generation is disabled",
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -376,12 +448,19 @@ async def generate_queries(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.QUERY_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -415,7 +494,12 @@ async def generate_autocompletion(
|
|||
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||
)
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -451,12 +535,19 @@ async def generate_autocompletion(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -472,7 +563,12 @@ async def generate_emoji(
|
|||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
|
|
@ -509,15 +605,25 @@ async def generate_emoji(
|
|||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.EMOJI_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
@ -532,7 +638,13 @@ async def generate_moa_response(
|
|||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
|
||||
if model_id not in models:
|
||||
|
|
@ -565,12 +677,19 @@ async def generate_moa_response(
|
|||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|||
from open_webui.utils.tools import get_tools_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -100,7 +105,7 @@ async def create_new_tools(
|
|||
specs = get_tools_specs(TOOLS[form_data.id])
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
|
||||
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if tools:
|
||||
|
|
@ -111,7 +116,7 @@ async def create_new_tools(
|
|||
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
|
@ -193,7 +198,7 @@ async def update_tools_by_id(
|
|||
"specs": specs,
|
||||
}
|
||||
|
||||
print(updated)
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
|
||||
if tools:
|
||||
|
|
@ -227,7 +232,11 @@ async def delete_tools_by_id(
|
|||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if tools.user_id != user.id and user.role != "admin":
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
|
|
@ -339,7 +348,7 @@ async def update_tools_valves_by_id(
|
|||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
|
@ -417,7 +426,7 @@ async def update_tools_user_valves_by_id(
|
|||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to update user valves by id {id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
|
|||
class FeaturesPermissions(BaseModel):
|
||||
web_search: bool = True
|
||||
image_generation: bool = True
|
||||
code_interpreter: bool = True
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
|
|
@ -152,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
|||
async def update_user_settings_by_session_user(
|
||||
form_data: UserSettings, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
|
||||
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,48 +1,84 @@
|
|||
import black
|
||||
import logging
|
||||
import markdown
|
||||
|
||||
from open_webui.models.chats import ChatTitleMessagesForm
|
||||
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
from open_webui.utils.misc import get_gravatar_url
|
||||
from open_webui.utils.pdf_generator import PDFGenerator
|
||||
from open_webui.utils.auth import get_admin_user
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
):
|
||||
async def get_gravatar(email: str, user=Depends(get_verified_user)):
|
||||
return get_gravatar_url(email)
|
||||
|
||||
|
||||
class CodeFormatRequest(BaseModel):
|
||||
class CodeForm(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
@router.post("/code/format")
|
||||
async def format_code(request: CodeFormatRequest):
|
||||
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
formatted_code = black.format_str(request.code, mode=black.Mode())
|
||||
formatted_code = black.format_str(form_data.code, mode=black.Mode())
|
||||
return {"code": formatted_code}
|
||||
except black.NothingChanged:
|
||||
return {"code": request.code}
|
||||
return {"code": form_data.code}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/code/execute")
|
||||
async def execute_code(
|
||||
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
|
||||
):
|
||||
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
|
||||
form_data.code,
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
|
||||
else None
|
||||
),
|
||||
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
|
||||
)
|
||||
|
||||
return output
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Code execution engine not supported",
|
||||
)
|
||||
|
||||
|
||||
class MarkdownForm(BaseModel):
|
||||
md: str
|
||||
|
||||
|
||||
@router.post("/markdown")
|
||||
async def get_html_from_markdown(
|
||||
form_data: MarkdownForm,
|
||||
form_data: MarkdownForm, user=Depends(get_verified_user)
|
||||
):
|
||||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
|
|
@ -54,7 +90,7 @@ class ChatForm(BaseModel):
|
|||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatTitleMessagesForm,
|
||||
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
|
||||
|
|
@ -65,7 +101,7 @@ async def download_chat_as_pdf(
|
|||
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Error generating PDF: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from open_webui.env import (
|
|||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
WEBSOCKET_MANAGER,
|
||||
WEBSOCKET_REDIS_URL,
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
)
|
||||
from open_webui.utils.auth import decode_token
|
||||
from open_webui.socket.utils import RedisDict, RedisLock
|
||||
|
|
@ -61,7 +62,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
clean_up_lock = RedisLock(
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=TIMEOUT_DURATION * 2,
|
||||
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
)
|
||||
aquire_func = clean_up_lock.aquire_lock
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
|
|
@ -279,8 +280,8 @@ def get_event_emitter(request_info):
|
|||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"chat_id": request_info.get("chat_id", None),
|
||||
"message_id": request_info.get("message_id", None),
|
||||
"data": event_data,
|
||||
},
|
||||
to=session_id,
|
||||
|
|
@ -325,19 +326,22 @@ def get_event_emitter(request_info):
|
|||
|
||||
|
||||
def get_event_call(request_info):
|
||||
async def __event_call__(event_data):
|
||||
async def __event_caller__(event_data):
|
||||
response = await sio.call(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"chat_id": request_info.get("chat_id", None),
|
||||
"message_id": request_info.get("message_id", None),
|
||||
"data": event_data,
|
||||
},
|
||||
to=request_info["session_id"],
|
||||
)
|
||||
return response
|
||||
|
||||
return __event_call__
|
||||
return __event_caller__
|
||||
|
||||
|
||||
get_event_caller = get_event_call
|
||||
|
||||
|
||||
def get_user_id_from_session_pool(sid):
|
||||
|
|
|
|||
|
After Width: | Height: | Size: 7.3 KiB |
|
After Width: | Height: | Size: 3.7 KiB |
|
After Width: | Height: | Size: 16 KiB |
|
After Width: | Height: | Size: 15 KiB |
|
After Width: | Height: | Size: 14 KiB |
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"name": "Open WebUI",
|
||||
"short_name": "WebUI",
|
||||
"icons": [
|
||||
{
|
||||
"src": "/static/web-app-manifest-192x192.png",
|
||||
"sizes": "192x192",
|
||||
"type": "image/png",
|
||||
"purpose": "maskable"
|
||||
},
|
||||
{
|
||||
"src": "/static/web-app-manifest-512x512.png",
|
||||
"sizes": "512x512",
|
||||
"type": "image/png",
|
||||
"purpose": "maskable"
|
||||
}
|
||||
],
|
||||
"theme_color": "#ffffff",
|
||||
"background_color": "#ffffff",
|
||||
"display": "standalone"
|
||||
}
|
||||
|
After Width: | Height: | Size: 5.3 KiB |
|
|
@ -9308,5 +9308,3 @@
|
|||
.json-schema-2020-12__title:first-of-type {
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
/*# sourceMappingURL=swagger-ui.css.map*/
|
||||
|
|
|
|||
|
After Width: | Height: | Size: 8.2 KiB |
|
After Width: | Height: | Size: 29 KiB |
|
|
@ -1,25 +1,41 @@
|
|||
import os
|
||||
import shutil
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO, Tuple
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
from open_webui.config import (
|
||||
S3_ACCESS_KEY_ID,
|
||||
S3_BUCKET_NAME,
|
||||
S3_ENDPOINT_URL,
|
||||
S3_KEY_PREFIX,
|
||||
S3_REGION_NAME,
|
||||
S3_SECRET_ACCESS_KEY,
|
||||
S3_USE_ACCELERATE_ENDPOINT,
|
||||
S3_ADDRESSING_STYLE,
|
||||
GCS_BUCKET_NAME,
|
||||
GOOGLE_APPLICATION_CREDENTIALS_JSON,
|
||||
AZURE_STORAGE_ENDPOINT,
|
||||
AZURE_STORAGE_CONTAINER_NAME,
|
||||
AZURE_STORAGE_KEY,
|
||||
STORAGE_PROVIDER,
|
||||
UPLOAD_DIR,
|
||||
)
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import GoogleCloudError, NotFound
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
class StorageProvider(ABC):
|
||||
|
|
@ -64,7 +80,7 @@ class LocalStorageProvider(StorageProvider):
|
|||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
else:
|
||||
print(f"File {file_path} not found in local storage.")
|
||||
log.warning(f"File {file_path} not found in local storage.")
|
||||
|
||||
@staticmethod
|
||||
def delete_all_files() -> None:
|
||||
|
|
@ -78,30 +94,52 @@ class LocalStorageProvider(StorageProvider):
|
|||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"Directory {UPLOAD_DIR} not found in local storage.")
|
||||
log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")
|
||||
|
||||
|
||||
class S3StorageProvider(StorageProvider):
|
||||
def __init__(self):
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=S3_REGION_NAME,
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
config = Config(
|
||||
s3={
|
||||
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
|
||||
"addressing_style": S3_ADDRESSING_STYLE,
|
||||
},
|
||||
)
|
||||
|
||||
# If access key and secret are provided, use them for authentication
|
||||
if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY:
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=S3_REGION_NAME,
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
config=config,
|
||||
)
|
||||
else:
|
||||
# If no explicit credentials are provided, fall back to default AWS credentials
|
||||
# This supports workload identity (IAM roles for EC2, EKS, etc.)
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=S3_REGION_NAME,
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
config=config,
|
||||
)
|
||||
|
||||
self.bucket_name = S3_BUCKET_NAME
|
||||
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
|
||||
|
||||
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
||||
"""Handles uploading of the file to S3 storage."""
|
||||
_, file_path = LocalStorageProvider.upload_file(file, filename)
|
||||
try:
|
||||
self.s3_client.upload_file(file_path, self.bucket_name, filename)
|
||||
s3_key = os.path.join(self.key_prefix, filename)
|
||||
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
|
||||
return (
|
||||
open(file_path, "rb").read(),
|
||||
"s3://" + self.bucket_name + "/" + filename,
|
||||
"s3://" + self.bucket_name + "/" + s3_key,
|
||||
)
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error uploading file to S3: {e}")
|
||||
|
|
@ -109,18 +147,18 @@ class S3StorageProvider(StorageProvider):
|
|||
def get_file(self, file_path: str) -> str:
|
||||
"""Handles downloading of the file from S3 storage."""
|
||||
try:
|
||||
bucket_name, key = file_path.split("//")[1].split("/")
|
||||
local_file_path = f"{UPLOAD_DIR}/{key}"
|
||||
self.s3_client.download_file(bucket_name, key, local_file_path)
|
||||
s3_key = self._extract_s3_key(file_path)
|
||||
local_file_path = self._get_local_file_path(s3_key)
|
||||
self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
|
||||
return local_file_path
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error downloading file from S3: {e}")
|
||||
|
||||
def delete_file(self, file_path: str) -> None:
|
||||
"""Handles deletion of the file from S3 storage."""
|
||||
filename = file_path.split("/")[-1]
|
||||
try:
|
||||
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
s3_key = self._extract_s3_key(file_path)
|
||||
self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error deleting file from S3: {e}")
|
||||
|
||||
|
|
@ -133,6 +171,10 @@ class S3StorageProvider(StorageProvider):
|
|||
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
|
||||
if "Contents" in response:
|
||||
for content in response["Contents"]:
|
||||
# Skip objects that were not uploaded from open-webui in the first place
|
||||
if not content["Key"].startswith(self.key_prefix):
|
||||
continue
|
||||
|
||||
self.s3_client.delete_object(
|
||||
Bucket=self.bucket_name, Key=content["Key"]
|
||||
)
|
||||
|
|
@ -142,6 +184,13 @@ class S3StorageProvider(StorageProvider):
|
|||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
|
||||
def _extract_s3_key(self, full_file_path: str) -> str:
|
||||
return "/".join(full_file_path.split("//")[1].split("/")[1:])
|
||||
|
||||
def _get_local_file_path(self, s3_key: str) -> str:
|
||||
return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
|
||||
|
||||
|
||||
class GCSStorageProvider(StorageProvider):
|
||||
def __init__(self):
|
||||
|
|
@ -207,6 +256,74 @@ class GCSStorageProvider(StorageProvider):
|
|||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
|
||||
class AzureStorageProvider(StorageProvider):
|
||||
def __init__(self):
|
||||
self.endpoint = AZURE_STORAGE_ENDPOINT
|
||||
self.container_name = AZURE_STORAGE_CONTAINER_NAME
|
||||
storage_key = AZURE_STORAGE_KEY
|
||||
|
||||
if storage_key:
|
||||
# Configure using the Azure Storage Account Endpoint and Key
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=self.endpoint, credential=storage_key
|
||||
)
|
||||
else:
|
||||
# Configure using the Azure Storage Account Endpoint and DefaultAzureCredential
|
||||
# If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=self.endpoint, credential=DefaultAzureCredential()
|
||||
)
|
||||
self.container_client = self.blob_service_client.get_container_client(
|
||||
self.container_name
|
||||
)
|
||||
|
||||
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
||||
"""Handles uploading of the file to Azure Blob Storage."""
|
||||
contents, file_path = LocalStorageProvider.upload_file(file, filename)
|
||||
try:
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
blob_client.upload_blob(contents, overwrite=True)
|
||||
return contents, f"{self.endpoint}/{self.container_name}/{filename}"
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")
|
||||
|
||||
def get_file(self, file_path: str) -> str:
|
||||
"""Handles downloading of the file from Azure Blob Storage."""
|
||||
try:
|
||||
filename = file_path.split("/")[-1]
|
||||
local_file_path = f"{UPLOAD_DIR}/{filename}"
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
with open(local_file_path, "wb") as download_file:
|
||||
download_file.write(blob_client.download_blob().readall())
|
||||
return local_file_path
|
||||
except ResourceNotFoundError as e:
|
||||
raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")
|
||||
|
||||
def delete_file(self, file_path: str) -> None:
|
||||
"""Handles deletion of the file from Azure Blob Storage."""
|
||||
try:
|
||||
filename = file_path.split("/")[-1]
|
||||
blob_client = self.container_client.get_blob_client(filename)
|
||||
blob_client.delete_blob()
|
||||
except ResourceNotFoundError as e:
|
||||
raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")
|
||||
|
||||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_file(file_path)
|
||||
|
||||
def delete_all_files(self) -> None:
|
||||
"""Handles deletion of all files from Azure Blob Storage."""
|
||||
try:
|
||||
blobs = self.container_client.list_blobs()
|
||||
for blob in blobs:
|
||||
self.container_client.delete_blob(blob.name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}")
|
||||
|
||||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
|
||||
def get_storage_provider(storage_provider: str):
|
||||
if storage_provider == "local":
|
||||
Storage = LocalStorageProvider()
|
||||
|
|
@ -214,6 +331,8 @@ def get_storage_provider(storage_provider: str):
|
|||
Storage = S3StorageProvider()
|
||||
elif storage_provider == "gcs":
|
||||
Storage = GCSStorageProvider()
|
||||
elif storage_provider == "azure":
|
||||
Storage = AzureStorageProvider()
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
|
||||
return Storage
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from moto import mock_aws
|
|||
from open_webui.storage import provider
|
||||
from gcp_storage_emulator.server import create_server
|
||||
from google.cloud import storage
|
||||
from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def mock_upload_dir(monkeypatch, tmp_path):
|
||||
|
|
@ -22,6 +24,7 @@ def test_imports():
|
|||
provider.LocalStorageProvider
|
||||
provider.S3StorageProvider
|
||||
provider.GCSStorageProvider
|
||||
provider.AzureStorageProvider
|
||||
provider.Storage
|
||||
|
||||
|
||||
|
|
@ -32,6 +35,8 @@ def test_get_storage_provider():
|
|||
assert isinstance(Storage, provider.S3StorageProvider)
|
||||
Storage = provider.get_storage_provider("gcs")
|
||||
assert isinstance(Storage, provider.GCSStorageProvider)
|
||||
Storage = provider.get_storage_provider("azure")
|
||||
assert isinstance(Storage, provider.AzureStorageProvider)
|
||||
with pytest.raises(RuntimeError):
|
||||
provider.get_storage_provider("invalid")
|
||||
|
||||
|
|
@ -48,6 +53,7 @@ def test_class_instantiation():
|
|||
provider.LocalStorageProvider()
|
||||
provider.S3StorageProvider()
|
||||
provider.GCSStorageProvider()
|
||||
provider.AzureStorageProvider()
|
||||
|
||||
|
||||
class TestLocalStorageProvider:
|
||||
|
|
@ -181,6 +187,17 @@ class TestS3StorageProvider:
|
|||
assert not (upload_dir / self.filename).exists()
|
||||
assert not (upload_dir / self.filename_extra).exists()
|
||||
|
||||
def test_init_without_credentials(self, monkeypatch):
|
||||
"""Test that S3StorageProvider can initialize without explicit credentials."""
|
||||
# Temporarily unset the environment variables
|
||||
monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
|
||||
monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
|
||||
|
||||
# Should not raise an exception
|
||||
storage = provider.S3StorageProvider()
|
||||
assert storage.s3_client is not None
|
||||
assert storage.bucket_name == provider.S3_BUCKET_NAME
|
||||
|
||||
|
||||
class TestGCSStorageProvider:
|
||||
Storage = provider.GCSStorageProvider()
|
||||
|
|
@ -272,3 +289,147 @@ class TestGCSStorageProvider:
|
|||
assert not (upload_dir / self.filename_extra).exists()
|
||||
assert self.Storage.bucket.get_blob(self.filename) == None
|
||||
assert self.Storage.bucket.get_blob(self.filename_extra) == None
|
||||
|
||||
|
||||
class TestAzureStorageProvider:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup_storage(self, monkeypatch):
|
||||
# Create mock Blob Service Client and related clients
|
||||
mock_blob_service_client = MagicMock()
|
||||
mock_container_client = MagicMock()
|
||||
mock_blob_client = MagicMock()
|
||||
|
||||
# Set up return values for the mock
|
||||
mock_blob_service_client.get_container_client.return_value = (
|
||||
mock_container_client
|
||||
)
|
||||
mock_container_client.get_blob_client.return_value = mock_blob_client
|
||||
|
||||
# Monkeypatch the Azure classes to return our mocks
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"BlobServiceClient",
|
||||
lambda *args, **kwargs: mock_blob_service_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"ContainerClient",
|
||||
lambda *args, **kwargs: mock_container_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
|
||||
)
|
||||
|
||||
self.Storage = provider.AzureStorageProvider()
|
||||
self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
|
||||
self.Storage.container_name = "my-container"
|
||||
self.file_content = b"test content"
|
||||
self.filename = "test.txt"
|
||||
self.filename_extra = "test_extra.txt"
|
||||
self.file_bytesio_empty = io.BytesIO()
|
||||
|
||||
# Apply mocks to the Storage instance
|
||||
self.Storage.blob_service_client = mock_blob_service_client
|
||||
self.Storage.container_client = mock_container_client
|
||||
|
||||
def test_upload_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
|
||||
# Simulate an error when container does not exist
|
||||
self.Storage.container_client.get_blob_client.side_effect = Exception(
|
||||
"Container does not exist"
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
|
||||
# Reset side effect and create container
|
||||
self.Storage.container_client.get_blob_client.side_effect = None
|
||||
self.Storage.create_container()
|
||||
contents, azure_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
|
||||
# Assertions
|
||||
self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
|
||||
self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
|
||||
self.file_content, overwrite=True
|
||||
)
|
||||
assert contents == self.file_content
|
||||
assert (
|
||||
azure_file_path
|
||||
== f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
|
||||
|
||||
def test_get_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.Storage.create_container()
|
||||
|
||||
# Mock upload behavior
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# Mock blob download behavior
|
||||
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
|
||||
self.file_content
|
||||
)
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
file_path = self.Storage.get_file(file_url)
|
||||
|
||||
assert file_path == str(upload_dir / self.filename)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
|
||||
def test_delete_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.Storage.create_container()
|
||||
|
||||
# Mock file upload
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# Mock deletion
|
||||
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
self.Storage.delete_file(file_url)
|
||||
|
||||
self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
|
||||
def test_delete_all_files(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.Storage.create_container()
|
||||
|
||||
# Mock file uploads
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
|
||||
|
||||
# Mock listing and deletion behavior
|
||||
self.Storage.container_client.list_blobs.return_value = [
|
||||
{"name": self.filename},
|
||||
{"name": self.filename_extra},
|
||||
]
|
||||
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
|
||||
|
||||
self.Storage.delete_all_files()
|
||||
|
||||
self.Storage.container_client.list_blobs.assert_called_once()
|
||||
self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
assert not (upload_dir / self.filename_extra).exists()
|
||||
|
||||
def test_get_file_not_found(self, monkeypatch):
|
||||
self.Storage.create_container()
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
# Mock behavior to raise an error for missing blobs
|
||||
self.Storage.container_client.get_blob_client().download_blob.side_effect = (
|
||||
Exception("Blob not found")
|
||||
)
|
||||
with pytest.raises(Exception, match="Blob not found"):
|
||||
self.Storage.get_file(file_url)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,249 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from asgiref.typing import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendCallable,
|
||||
ASGISendEvent,
|
||||
Scope as ASGIScope,
|
||||
)
|
||||
from loguru import logger
|
||||
from starlette.requests import Request
|
||||
|
||||
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
|
||||
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Logger
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuditLogEntry:
|
||||
# `Metadata` audit level properties
|
||||
id: str
|
||||
user: dict[str, Any]
|
||||
audit_level: str
|
||||
verb: str
|
||||
request_uri: str
|
||||
user_agent: Optional[str] = None
|
||||
source_ip: Optional[str] = None
|
||||
# `Request` audit level properties
|
||||
request_object: Any = None
|
||||
# `Request Response` level
|
||||
response_object: Any = None
|
||||
response_status_code: Optional[int] = None
|
||||
|
||||
|
||||
class AuditLevel(str, Enum):
|
||||
NONE = "NONE"
|
||||
METADATA = "METADATA"
|
||||
REQUEST = "REQUEST"
|
||||
REQUEST_RESPONSE = "REQUEST_RESPONSE"
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly.
|
||||
|
||||
Parameters:
|
||||
logger (Logger): An instance of Loguru’s logger.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: "Logger"):
|
||||
self.logger = logger.bind(auditable=True)
|
||||
|
||||
def write(
|
||||
self,
|
||||
audit_entry: AuditLogEntry,
|
||||
*,
|
||||
log_level: str = "INFO",
|
||||
extra: Optional[dict] = None,
|
||||
):
|
||||
|
||||
entry = asdict(audit_entry)
|
||||
|
||||
if extra:
|
||||
entry["extra"] = extra
|
||||
|
||||
self.logger.log(
|
||||
log_level,
|
||||
"",
|
||||
**entry,
|
||||
)
|
||||
|
||||
|
||||
class AuditContext:
|
||||
"""
|
||||
Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
|
||||
|
||||
Attributes:
|
||||
request_body (bytearray): Accumulated request payload.
|
||||
response_body (bytearray): Accumulated response payload.
|
||||
max_body_size (int): Maximum number of bytes to capture.
|
||||
metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
|
||||
self.request_body = bytearray()
|
||||
self.response_body = bytearray()
|
||||
self.max_body_size = max_body_size
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
|
||||
def add_request_chunk(self, chunk: bytes):
|
||||
if len(self.request_body) < self.max_body_size:
|
||||
self.request_body.extend(
|
||||
chunk[: self.max_body_size - len(self.request_body)]
|
||||
)
|
||||
|
||||
def add_response_chunk(self, chunk: bytes):
|
||||
if len(self.response_body) < self.max_body_size:
|
||||
self.response_body.extend(
|
||||
chunk[: self.max_body_size - len(self.response_body)]
|
||||
)
|
||||
|
||||
|
||||
class AuditLoggingMiddleware:
|
||||
"""
|
||||
ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
|
||||
"""
|
||||
|
||||
AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGI3Application,
|
||||
*,
|
||||
excluded_paths: Optional[list[str]] = None,
|
||||
max_body_size: int = MAX_BODY_LOG_SIZE,
|
||||
audit_level: AuditLevel = AuditLevel.NONE,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.audit_logger = AuditLogger(logger)
|
||||
self.excluded_paths = excluded_paths or []
|
||||
self.max_body_size = max_body_size
|
||||
self.audit_level = audit_level
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: ASGIScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
request = Request(scope=cast(MutableMapping, scope))
|
||||
|
||||
if self._should_skip_auditing(request):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async with self._audit_context(request) as context:
|
||||
|
||||
async def send_wrapper(message: ASGISendEvent) -> None:
|
||||
if self.audit_level == AuditLevel.REQUEST_RESPONSE:
|
||||
await self._capture_response(message, context)
|
||||
|
||||
await send(message)
|
||||
|
||||
original_receive = receive
|
||||
|
||||
async def receive_wrapper() -> ASGIReceiveEvent:
|
||||
nonlocal original_receive
|
||||
message = await original_receive()
|
||||
|
||||
if self.audit_level in (
|
||||
AuditLevel.REQUEST,
|
||||
AuditLevel.REQUEST_RESPONSE,
|
||||
):
|
||||
await self._capture_request(message, context)
|
||||
|
||||
return message
|
||||
|
||||
await self.app(scope, receive_wrapper, send_wrapper)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _audit_context(
|
||||
self, request: Request
|
||||
) -> AsyncGenerator[AuditContext, None]:
|
||||
"""
|
||||
async context manager that ensures that an audit log entry is recorded after the request is processed.
|
||||
"""
|
||||
context = AuditContext()
|
||||
try:
|
||||
yield context
|
||||
finally:
|
||||
await self._log_audit_entry(request, context)
|
||||
|
||||
async def _get_authenticated_user(self, request: Request) -> UserModel:
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
assert auth_header
|
||||
user = get_current_user(request, None, get_http_authorization_cred(auth_header))
|
||||
|
||||
return user
|
||||
|
||||
def _should_skip_auditing(self, request: Request) -> bool:
|
||||
if (
|
||||
request.method not in {"POST", "PUT", "PATCH", "DELETE"}
|
||||
or AUDIT_LOG_LEVEL == "NONE"
|
||||
or not request.headers.get("authorization")
|
||||
):
|
||||
return True
|
||||
# match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
|
||||
pattern = re.compile(
|
||||
r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
|
||||
)
|
||||
if pattern.match(request.url.path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
|
||||
if message["type"] == "http.request":
|
||||
body = message.get("body", b"")
|
||||
context.add_request_chunk(body)
|
||||
|
||||
async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
|
||||
if message["type"] == "http.response.start":
|
||||
context.metadata["response_status_code"] = message["status"]
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
context.add_response_chunk(body)
|
||||
|
||||
async def _log_audit_entry(self, request: Request, context: AuditContext):
|
||||
try:
|
||||
user = await self._get_authenticated_user(request)
|
||||
|
||||
entry = AuditLogEntry(
|
||||
id=str(uuid.uuid4()),
|
||||
user=user.model_dump(include={"id", "name", "email", "role"}),
|
||||
audit_level=self.audit_level.value,
|
||||
verb=request.method,
|
||||
request_uri=str(request.url),
|
||||
response_status_code=context.metadata.get("response_status_code", None),
|
||||
source_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
request_object=context.request_body.decode("utf-8", errors="replace"),
|
||||
response_object=context.response_body.decode("utf-8", errors="replace"),
|
||||
)
|
||||
|
||||
self.audit_logger.write(entry)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit entry: {str(e)}")
|
||||
|
|
@ -1,6 +1,12 @@
|
|||
import logging
|
||||
import uuid
|
||||
import jwt
|
||||
import base64
|
||||
import hmac
|
||||
import hashlib
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
|
@ -8,14 +14,22 @@ from typing import Optional, Union, List, Dict
|
|||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SECRET_KEY
|
||||
from open_webui.env import (
|
||||
WEBUI_SECRET_KEY,
|
||||
TRUSTED_SIGNATURE_KEY,
|
||||
STATIC_DIR,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
|
||||
|
||||
SESSION_SECRET = WEBUI_SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
|
|
@ -24,6 +38,67 @@ ALGORITHM = "HS256"
|
|||
# Auth Utils
|
||||
##############
|
||||
|
||||
|
||||
def verify_signature(payload: str, signature: str) -> bool:
|
||||
"""
|
||||
Verifies the HMAC signature of the received payload.
|
||||
"""
|
||||
try:
|
||||
expected_signature = base64.b64encode(
|
||||
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
|
||||
).decode()
|
||||
|
||||
# Compare securely to prevent timing attacks
|
||||
return hmac.compare_digest(expected_signature, signature)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def override_static(path: str, content: str):
|
||||
# Ensure path is safe
|
||||
if "/" in path or ".." in path:
|
||||
log.error(f"Invalid path: {path}")
|
||||
return
|
||||
|
||||
file_path = os.path.join(STATIC_DIR, path)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
|
||||
|
||||
|
||||
def get_license_data(app, key):
|
||||
if key:
|
||||
try:
|
||||
res = requests.post(
|
||||
"https://api.openwebui.com/api/v1/license",
|
||||
json={"key": key, "version": "1"},
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
if getattr(res, "ok", False):
|
||||
payload = getattr(res, "json", lambda: {})()
|
||||
for k, v in payload.items():
|
||||
if k == "resources":
|
||||
for p, c in v.items():
|
||||
globals().get("override_static", lambda a, b: None)(p, c)
|
||||
elif k == "count":
|
||||
setattr(app.state, "USER_COUNT", v)
|
||||
elif k == "name":
|
||||
setattr(app.state, "WEBUI_NAME", v)
|
||||
elif k == "metadata":
|
||||
setattr(app.state, "LICENSE_METADATA", v)
|
||||
return True
|
||||
else:
|
||||
log.error(
|
||||
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
|
||||
)
|
||||
except Exception as ex:
|
||||
log.exception(f"License: Uncaught Exception: {ex}")
|
||||
return False
|
||||
|
||||
|
||||
bearer_security = HTTPBearer(auto_error=False)
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
|
@ -76,6 +151,7 @@ def get_http_authorization_cred(auth_header: str):
|
|||
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
token = None
|
||||
|
|
@ -128,7 +204,10 @@ def get_current_user(
|
|||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -7,14 +7,17 @@ from typing import Any, Optional
|
|||
import random
|
||||
import json
|
||||
import inspect
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
from fastapi import Request, status
|
||||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
from open_webui.socket.main import (
|
||||
sio,
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
)
|
||||
|
|
@ -44,6 +47,10 @@ from open_webui.utils.response import (
|
|||
convert_response_ollama_to_openai,
|
||||
convert_streaming_response_ollama_to_openai,
|
||||
)
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
|
||||
|
|
@ -53,108 +60,224 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def generate_direct_chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user: Any,
|
||||
models: dict,
|
||||
):
|
||||
log.info("generate_direct_chat_completion")
|
||||
|
||||
metadata = form_data.pop("metadata", {})
|
||||
|
||||
user_id = metadata.get("user_id")
|
||||
session_id = metadata.get("session_id")
|
||||
request_id = str(uuid.uuid4()) # Generate a unique request ID
|
||||
|
||||
event_caller = get_event_call(metadata)
|
||||
|
||||
channel = f"{user_id}:{session_id}:{request_id}"
|
||||
|
||||
if form_data.get("stream"):
|
||||
q = asyncio.Queue()
|
||||
|
||||
async def message_listener(sid, data):
|
||||
"""
|
||||
Handle received socket messages and push them into the queue.
|
||||
"""
|
||||
await q.put(data)
|
||||
|
||||
# Register the listener
|
||||
sio.on(channel, message_listener)
|
||||
|
||||
# Start processing chat completion in background
|
||||
res = await event_caller(
|
||||
{
|
||||
"type": "request:chat:completion",
|
||||
"data": {
|
||||
"form_data": form_data,
|
||||
"model": models[form_data["model"]],
|
||||
"channel": channel,
|
||||
"session_id": session_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
log.info(f"res: {res}")
|
||||
|
||||
if res.get("status", False):
|
||||
# Define a generator to stream responses
|
||||
async def event_generator():
|
||||
nonlocal q
|
||||
try:
|
||||
while True:
|
||||
data = await q.get() # Wait for new messages
|
||||
if isinstance(data, dict):
|
||||
if "done" in data and data["done"]:
|
||||
break # Stop streaming when 'done' is received
|
||||
|
||||
yield f"data: {json.dumps(data)}\n\n"
|
||||
elif isinstance(data, str):
|
||||
yield data
|
||||
except Exception as e:
|
||||
log.debug(f"Error in event generator: {e}")
|
||||
pass
|
||||
|
||||
# Define a background task to run the event generator
|
||||
async def background():
|
||||
try:
|
||||
del sio.handlers["/"][channel]
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Return the streaming response
|
||||
return StreamingResponse(
|
||||
event_generator(), media_type="text/event-stream", background=background
|
||||
)
|
||||
else:
|
||||
raise Exception(str(res))
|
||||
else:
|
||||
res = await event_caller(
|
||||
{
|
||||
"type": "request:chat:completion",
|
||||
"data": {
|
||||
"form_data": form_data,
|
||||
"model": models[form_data["model"]],
|
||||
"channel": channel,
|
||||
"session_id": session_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if "error" in res and res["error"]:
|
||||
raise Exception(res["error"])
|
||||
|
||||
return res
|
||||
|
||||
|
||||
async def generate_chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user: Any,
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
log.debug(f"generate_chat_completion: {form_data}")
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
models = request.app.state.MODELS
|
||||
if hasattr(request.state, "metadata"):
|
||||
if "metadata" not in form_data:
|
||||
form_data["metadata"] = request.state.metadata
|
||||
else:
|
||||
form_data["metadata"] = {
|
||||
**form_data["metadata"],
|
||||
**request.state.metadata,
|
||||
}
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
log.debug(f"direct connection to model: {models}")
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
|
||||
# Process the form_data through the pipeline
|
||||
try:
|
||||
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
model = models[model_id]
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if model["owned_by"] == "arena":
|
||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
||||
if model_ids and filter_mode == "exclude":
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
||||
]
|
||||
|
||||
selected_model_id = None
|
||||
if isinstance(model_ids, list) and model_ids:
|
||||
selected_model_id = random.choice(model_ids)
|
||||
else:
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena"
|
||||
]
|
||||
selected_model_id = random.choice(model_ids)
|
||||
|
||||
form_data["model"] = selected_model_id
|
||||
|
||||
if form_data.get("stream") == True:
|
||||
|
||||
async def stream_wrapper(stream):
|
||||
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
response = await generate_chat_completion(
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator),
|
||||
media_type="text/event-stream",
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
**(
|
||||
await generate_chat_completion(
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
),
|
||||
"selected_model_id": selected_model_id,
|
||||
}
|
||||
|
||||
if model.get("pipe"):
|
||||
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||
return await generate_function_chat_completion(
|
||||
if getattr(request.state, "direct", False):
|
||||
return await generate_direct_chat_completion(
|
||||
request, form_data, user=user, models=models
|
||||
)
|
||||
if model["owned_by"] == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
response = await generate_ollama_chat_completion(
|
||||
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||
)
|
||||
if form_data.get("stream"):
|
||||
response.headers["content-type"] = "text/event-stream"
|
||||
return StreamingResponse(
|
||||
convert_streaming_response_ollama_to_openai(response),
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_openai_chat_completion(
|
||||
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||
)
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if model.get("owned_by") == "arena":
|
||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
||||
if model_ids and filter_mode == "exclude":
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
||||
]
|
||||
|
||||
selected_model_id = None
|
||||
if isinstance(model_ids, list) and model_ids:
|
||||
selected_model_id = random.choice(model_ids)
|
||||
else:
|
||||
model_ids = [
|
||||
model["id"]
|
||||
for model in list(request.app.state.MODELS.values())
|
||||
if model.get("owned_by") != "arena"
|
||||
]
|
||||
selected_model_id = random.choice(model_ids)
|
||||
|
||||
form_data["model"] = selected_model_id
|
||||
|
||||
if form_data.get("stream") == True:
|
||||
|
||||
async def stream_wrapper(stream):
|
||||
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
response = await generate_chat_completion(
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator),
|
||||
media_type="text/event-stream",
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
**(
|
||||
await generate_chat_completion(
|
||||
request, form_data, user, bypass_filter=True
|
||||
)
|
||||
),
|
||||
"selected_model_id": selected_model_id,
|
||||
}
|
||||
|
||||
if model.get("pipe"):
|
||||
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||
return await generate_function_chat_completion(
|
||||
request, form_data, user=user, models=models
|
||||
)
|
||||
if model.get("owned_by") == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
response = await generate_ollama_chat_completion(
|
||||
request=request,
|
||||
form_data=form_data,
|
||||
user=user,
|
||||
bypass_filter=bypass_filter,
|
||||
)
|
||||
if form_data.get("stream"):
|
||||
response.headers["content-type"] = "text/event-stream"
|
||||
return StreamingResponse(
|
||||
convert_streaming_response_ollama_to_openai(response),
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_openai_chat_completion(
|
||||
request=request,
|
||||
form_data=form_data,
|
||||
user=user,
|
||||
bypass_filter=bypass_filter,
|
||||
)
|
||||
|
||||
|
||||
chat_completion = generate_chat_completion
|
||||
|
|
@ -162,8 +285,14 @@ chat_completion = generate_chat_completion
|
|||
|
||||
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
models = request.app.state.MODELS
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
|
@ -173,120 +302,47 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||
model = models[model_id]
|
||||
|
||||
try:
|
||||
data = process_pipeline_outlet_filter(request, data, user, models)
|
||||
data = await process_pipeline_outlet_filter(request, data, user, models)
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
|
||||
__event_call__ = get_event_call(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
]
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if not hasattr(function_module, "outlet"):
|
||||
continue
|
||||
try:
|
||||
outlet = function_module.outlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(outlet)
|
||||
params = {"body": data}
|
||||
|
||||
# Extra parameters to be passed to the function
|
||||
extra_params = {
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__request__": request,
|
||||
}
|
||||
|
||||
# Add extra params in contained in function signature
|
||||
for key, value in extra_params.items():
|
||||
if key in sig.parameters:
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(**params)
|
||||
else:
|
||||
data = outlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
return data
|
||||
result, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_functions=filter_functions,
|
||||
filter_type="outlet",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
|
||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||
|
|
@ -300,8 +356,14 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
models = request.app.state.MODELS
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
|
@ -375,7 +437,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,210 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
class ResultModel(BaseModel):
|
||||
"""
|
||||
Execute Code Result Model
|
||||
"""
|
||||
|
||||
stdout: Optional[str] = ""
|
||||
stderr: Optional[str] = ""
|
||||
result: Optional[str] = ""
|
||||
|
||||
|
||||
class JupyterCodeExecuter:
|
||||
"""
|
||||
Execute code in jupyter notebook
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
code: str,
|
||||
token: str = "",
|
||||
password: str = "",
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""
|
||||
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
||||
:param code: Code to execute
|
||||
:param token: Jupyter authentication token (optional)
|
||||
:param password: Jupyter password (optional)
|
||||
:param timeout: WebSocket timeout in seconds (default: 60s)
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.code = code
|
||||
self.token = token
|
||||
self.password = password
|
||||
self.timeout = timeout
|
||||
self.kernel_id = ""
|
||||
self.session = aiohttp.ClientSession(base_url=self.base_url)
|
||||
self.params = {}
|
||||
self.result = ResultModel()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.kernel_id:
|
||||
try:
|
||||
async with self.session.delete(
|
||||
f"/api/kernels/{self.kernel_id}", params=self.params
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
except Exception as err:
|
||||
logger.exception("close kernel failed, %s", err)
|
||||
await self.session.close()
|
||||
|
||||
async def run(self) -> ResultModel:
|
||||
try:
|
||||
await self.sign_in()
|
||||
await self.init_kernel()
|
||||
await self.execute_code()
|
||||
except Exception as err:
|
||||
logger.exception("execute code failed, %s", err)
|
||||
self.result.stderr = f"Error: {err}"
|
||||
return self.result
|
||||
|
||||
async def sign_in(self) -> None:
|
||||
# password authentication
|
||||
if self.password and not self.token:
|
||||
async with self.session.get("/login") as response:
|
||||
response.raise_for_status()
|
||||
xsrf_token = response.cookies["_xsrf"].value
|
||||
if not xsrf_token:
|
||||
raise ValueError("_xsrf token not found")
|
||||
self.session.cookie_jar.update_cookies(response.cookies)
|
||||
self.session.headers.update({"X-XSRFToken": xsrf_token})
|
||||
async with self.session.post(
|
||||
"/login",
|
||||
data={"_xsrf": xsrf_token, "password": self.password},
|
||||
allow_redirects=False,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
self.session.cookie_jar.update_cookies(response.cookies)
|
||||
|
||||
# token authentication
|
||||
if self.token:
|
||||
self.params.update({"token": self.token})
|
||||
|
||||
async def init_kernel(self) -> None:
|
||||
async with self.session.post(
|
||||
url="/api/kernels", params=self.params
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
kernel_data = await response.json()
|
||||
self.kernel_id = kernel_data["id"]
|
||||
|
||||
def init_ws(self) -> (str, dict):
|
||||
ws_base = self.base_url.replace("http", "ws")
|
||||
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
|
||||
websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
|
||||
ws_headers = {}
|
||||
if self.password and not self.token:
|
||||
ws_headers = {
|
||||
"Cookie": "; ".join(
|
||||
[
|
||||
f"{cookie.key}={cookie.value}"
|
||||
for cookie in self.session.cookie_jar
|
||||
]
|
||||
),
|
||||
**self.session.headers,
|
||||
}
|
||||
return websocket_url, ws_headers
|
||||
|
||||
async def execute_code(self) -> None:
|
||||
# initialize ws
|
||||
websocket_url, ws_headers = self.init_ws()
|
||||
# execute
|
||||
async with websockets.connect(
|
||||
websocket_url, additional_headers=ws_headers
|
||||
) as ws:
|
||||
await self.execute_in_jupyter(ws)
|
||||
|
||||
async def execute_in_jupyter(self, ws) -> None:
|
||||
# send message
|
||||
msg_id = uuid.uuid4().hex
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "user",
|
||||
"session": uuid.uuid4().hex,
|
||||
"date": "",
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": self.code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"channel": "shell",
|
||||
}
|
||||
)
|
||||
)
|
||||
# parse message
|
||||
stdout, stderr, result = "", "", []
|
||||
while True:
|
||||
try:
|
||||
# wait for message
|
||||
message = await asyncio.wait_for(ws.recv(), self.timeout)
|
||||
message_data = json.loads(message)
|
||||
# msg id not match, skip
|
||||
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
|
||||
continue
|
||||
# check message type
|
||||
msg_type = message_data.get("msg_type")
|
||||
match msg_type:
|
||||
case "stream":
|
||||
if message_data["content"]["name"] == "stdout":
|
||||
stdout += message_data["content"]["text"]
|
||||
elif message_data["content"]["name"] == "stderr":
|
||||
stderr += message_data["content"]["text"]
|
||||
case "execute_result" | "display_data":
|
||||
data = message_data["content"]["data"]
|
||||
if "image/png" in data:
|
||||
result.append(f"data:image/png;base64,{data['image/png']}")
|
||||
elif "text/plain" in data:
|
||||
result.append(data["text/plain"])
|
||||
case "error":
|
||||
stderr += "\n".join(message_data["content"]["traceback"])
|
||||
case "status":
|
||||
if message_data["content"]["execution_state"] == "idle":
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
stderr += "\nExecution timed out."
|
||||
break
|
||||
self.result.stdout = stdout.strip()
|
||||
self.result.stderr = stderr.strip()
|
||||
self.result.result = "\n".join(result).strip() if result else ""
|
||||
|
||||
|
||||
async def execute_code_jupyter(
|
||||
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
||||
) -> dict:
|
||||
async with JupyterCodeExecuter(
|
||||
base_url, code, token, password, timeout
|
||||
) as executor:
|
||||
result = await executor.run()
|
||||
return result.model_dump()
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
import inspect
|
||||
import logging
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model: dict):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
|
||||
async def process_filter_functions(
|
||||
request, filter_functions, filter_type, form_data, extra_params
|
||||
):
|
||||
skip_files = None
|
||||
|
||||
for function in filter_functions:
|
||||
filter = function
|
||||
filter_id = function.id
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
continue
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare parameters
|
||||
sig = inspect.signature(handler)
|
||||
|
||||
params = {"body": form_data}
|
||||
if filter_type == "stream":
|
||||
params = {"event": form_data}
|
||||
|
||||
params = params | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in sig.parameters
|
||||
}
|
||||
|
||||
# Handle user parameters
|
||||
if "__user__" in sig.parameters:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to get user values: {e}")
|
||||
|
||||
# Execute handler
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
form_data = await handler(**params)
|
||||
else:
|
||||
form_data = handler(**params)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||
raise e
|
||||
|
||||
# Handle file cleanup for inlet
|
||||
if skip_files and "files" in form_data.get("metadata", {}):
|
||||
del form_data["metadata"]["files"]
|
||||
|
||||
return form_data, {}
|
||||
|
|
@ -161,7 +161,7 @@ async def comfyui_generate_image(
|
|||
seed = (
|
||||
payload.seed
|
||||
if payload.seed
|
||||
else random.randint(0, 18446744073709551614)
|
||||
else random.randint(0, 1125899906842624)
|
||||
)
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"][node.key] = seed
|
||||
|
|
|
|||
|
|
@ -0,0 +1,140 @@
|
|||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from open_webui.env import (
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
AUDIT_LOG_LEVEL,
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def stdout_format(record: "Record") -> str:
|
||||
"""
|
||||
Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
|
||||
|
||||
Parameters:
|
||||
record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
|
||||
Returns:
|
||||
str: A formatted log string intended for stdout.
|
||||
"""
|
||||
record["extra"]["extra_json"] = json.dumps(record["extra"])
|
||||
return (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
||||
"<level>{message}</level> - {extra[extra_json]}"
|
||||
"\n{exception}"
|
||||
)
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
"""
|
||||
Intercepts log records from Python's standard logging module
|
||||
and redirects them to Loguru's logger.
|
||||
"""
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Called by the standard logging module for each log event.
|
||||
It transforms the standard `LogRecord` into a format compatible with Loguru
|
||||
and passes it to Loguru's logger.
|
||||
"""
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
frame, depth = sys._getframe(6), 6
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
def file_format(record: "Record"):
|
||||
"""
|
||||
Formats audit log records into a structured JSON string for file output.
|
||||
|
||||
Parameters:
|
||||
record (Record): A Loguru record containing extra audit data.
|
||||
Returns:
|
||||
str: A JSON-formatted string representing the audit data.
|
||||
"""
|
||||
|
||||
audit_data = {
|
||||
"id": record["extra"].get("id", ""),
|
||||
"timestamp": int(record["time"].timestamp()),
|
||||
"user": record["extra"].get("user", dict()),
|
||||
"audit_level": record["extra"].get("audit_level", ""),
|
||||
"verb": record["extra"].get("verb", ""),
|
||||
"request_uri": record["extra"].get("request_uri", ""),
|
||||
"response_status_code": record["extra"].get("response_status_code", 0),
|
||||
"source_ip": record["extra"].get("source_ip", ""),
|
||||
"user_agent": record["extra"].get("user_agent", ""),
|
||||
"request_object": record["extra"].get("request_object", b""),
|
||||
"response_object": record["extra"].get("response_object", b""),
|
||||
"extra": record["extra"].get("extra", {}),
|
||||
}
|
||||
|
||||
record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
|
||||
return "{extra[file_extra]}\n"
|
||||
|
||||
|
||||
def start_logger():
|
||||
"""
|
||||
Initializes and configures Loguru's logger with distinct handlers:
|
||||
|
||||
A console (stdout) handler for general log messages (excluding those marked as auditable).
|
||||
An optional file handler for audit logs if audit logging is enabled.
|
||||
Additionally, this function reconfigures Python’s standard logging to route through Loguru and adjusts logging levels for Uvicorn.
|
||||
|
||||
Parameters:
|
||||
enable_audit_logging (bool): Determines whether audit-specific log entries should be recorded to file.
|
||||
"""
|
||||
logger.remove()
|
||||
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=GLOBAL_LOG_LEVEL,
|
||||
format=stdout_format,
|
||||
filter=lambda record: "auditable" not in record["extra"],
|
||||
)
|
||||
|
||||
if AUDIT_LOG_LEVEL != "NONE":
|
||||
try:
|
||||
logger.add(
|
||||
AUDIT_LOGS_FILE_PATH,
|
||||
level="INFO",
|
||||
rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
|
||||
compression="zip",
|
||||
format=file_format,
|
||||
filter=lambda record: record["extra"].get("auditable") is True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize audit log file handler: {str(e)}")
|
||||
|
||||
logging.basicConfig(
|
||||
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
|
||||
)
|
||||
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
|
||||
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
|
||||
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
|
||||
uvicorn_logger.handlers = []
|
||||
for uvicorn_logger_name in ["uvicorn.access"]:
|
||||
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
|
||||
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
|
||||
uvicorn_logger.handlers = [InterceptHandler()]
|
||||
|
||||
logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
|
|
@ -2,9 +2,27 @@ import hashlib
|
|||
import re
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
import json
|
||||
|
||||
|
||||
import collections.abc
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def deep_update(d, u):
|
||||
for k, v in u.items():
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
d[k] = deep_update(d.get(k, {}), v)
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
def get_message_list(messages, message_id):
|
||||
|
|
@ -20,7 +38,7 @@ def get_message_list(messages, message_id):
|
|||
current_message = messages.get(message_id)
|
||||
|
||||
if not current_message:
|
||||
return f"Message ID {message_id} not found in the history."
|
||||
return None
|
||||
|
||||
# Reconstruct the chain by following the parentId links
|
||||
message_list = []
|
||||
|
|
@ -131,6 +149,44 @@ def add_or_update_system_message(content: str, messages: list[dict]):
|
|||
return messages
|
||||
|
||||
|
||||
def add_or_update_user_message(content: str, messages: list[dict]):
|
||||
"""
|
||||
Adds a new user message at the end of the messages list
|
||||
or updates the existing user message at the end.
|
||||
|
||||
:param msg: The message to be added or appended.
|
||||
:param messages: The list of message dictionaries.
|
||||
:return: The updated list of message dictionaries.
|
||||
"""
|
||||
|
||||
if messages and messages[-1].get("role") == "user":
|
||||
messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
|
||||
else:
|
||||
# Insert at the end
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def append_or_update_assistant_message(content: str, messages: list[dict]):
|
||||
"""
|
||||
Adds a new assistant message at the end of the messages list
|
||||
or updates the existing assistant message at the end.
|
||||
|
||||
:param msg: The message to be added or appended.
|
||||
:param messages: The list of message dictionaries.
|
||||
:return: The updated list of message dictionaries.
|
||||
"""
|
||||
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
|
||||
else:
|
||||
# Insert at the end
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def openai_chat_message_template(model: str):
|
||||
return {
|
||||
"id": f"{model}-{str(uuid.uuid4())}",
|
||||
|
|
@ -141,13 +197,24 @@ def openai_chat_message_template(model: str):
|
|||
|
||||
|
||||
def openai_chat_chunk_message_template(
|
||||
model: str, message: Optional[str] = None, usage: Optional[dict] = None
|
||||
model: str,
|
||||
content: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict]] = None,
|
||||
usage: Optional[dict] = None,
|
||||
) -> dict:
|
||||
template = openai_chat_message_template(model)
|
||||
template["object"] = "chat.completion.chunk"
|
||||
if message:
|
||||
template["choices"][0]["delta"] = {"content": message}
|
||||
else:
|
||||
|
||||
template["choices"][0]["index"] = 0
|
||||
template["choices"][0]["delta"] = {}
|
||||
|
||||
if content:
|
||||
template["choices"][0]["delta"]["content"] = content
|
||||
|
||||
if tool_calls:
|
||||
template["choices"][0]["delta"]["tool_calls"] = tool_calls
|
||||
|
||||
if not content and not tool_calls:
|
||||
template["choices"][0]["finish_reason"] = "stop"
|
||||
|
||||
if usage:
|
||||
|
|
@ -156,12 +223,20 @@ def openai_chat_chunk_message_template(
|
|||
|
||||
|
||||
def openai_chat_completion_message_template(
|
||||
model: str, message: Optional[str] = None, usage: Optional[dict] = None
|
||||
model: str,
|
||||
message: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict]] = None,
|
||||
usage: Optional[dict] = None,
|
||||
) -> dict:
|
||||
template = openai_chat_message_template(model)
|
||||
template["object"] = "chat.completion"
|
||||
if message is not None:
|
||||
template["choices"][0]["message"] = {"content": message, "role": "assistant"}
|
||||
template["choices"][0]["message"] = {
|
||||
"content": message,
|
||||
"role": "assistant",
|
||||
**({"tool_calls": tool_calls} if tool_calls else {}),
|
||||
}
|
||||
|
||||
template["choices"][0]["finish_reason"] = "stop"
|
||||
|
||||
if usage:
|
||||
|
|
@ -183,11 +258,12 @@ def get_gravatar_url(email):
|
|||
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
|
||||
|
||||
|
||||
def calculate_sha256(file):
|
||||
def calculate_sha256(file_path, chunk_size):
|
||||
# Compute SHA-256 hash of a file efficiently in chunks
|
||||
sha256 = hashlib.sha256()
|
||||
# Read the file in chunks to efficiently handle large files
|
||||
for chunk in iter(lambda: file.read(8192), b""):
|
||||
sha256.update(chunk)
|
||||
with open(file_path, "rb") as f:
|
||||
while chunk := f.read(chunk_size):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
|
|
@ -342,7 +418,7 @@ def parse_ollama_modelfile(model_text):
|
|||
elif param_type is bool:
|
||||
value = value.lower() == "true"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(f"Failed to parse parameter {param}: {e}")
|
||||
continue
|
||||
|
||||
data["params"][param] = value
|
||||
|
|
@ -375,3 +451,15 @@ def parse_ollama_modelfile(model_text):
|
|||
data["params"]["messages"] = messages
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def convert_logit_bias_input_to_json(user_input):
|
||||
logit_bias_pairs = user_input.split(",")
|
||||
logit_bias_json = {}
|
||||
for pair in logit_bias_pairs:
|
||||
token, bias = pair.split(":")
|
||||
token = str(token.strip())
|
||||
bias = int(bias.strip())
|
||||
bias = 100 if bias > 100 else -100 if bias < -100 else bias
|
||||
logit_bias_json[token] = bias
|
||||
return json.dumps(logit_bias_json)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from open_webui.config import (
|
|||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
from open_webui.models.users import UserModel
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
|
|
@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def get_all_base_models(request: Request):
|
||||
async def get_all_base_models(request: Request, user: UserModel = None):
|
||||
function_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
if request.app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await openai.get_all_models(request)
|
||||
openai_models = await openai.get_all_models(request, user=user)
|
||||
openai_models = openai_models["data"]
|
||||
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await ollama.get_all_models(request)
|
||||
ollama_models = await ollama.get_all_models(request, user=user)
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
|
|
@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
|
|||
return models
|
||||
|
||||
|
||||
async def get_all_models(request):
|
||||
models = await get_all_base_models(request)
|
||||
async def get_all_models(request, user: UserModel = None):
|
||||
models = await get_all_base_models(request, user=user)
|
||||
|
||||
# If there are no models, return an empty list
|
||||
if len(models) == 0:
|
||||
|
|
@ -142,7 +143,7 @@ async def get_all_models(request):
|
|||
custom_model.base_model_id == model["id"]
|
||||
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||
):
|
||||
owned_by = model["owned_by"]
|
||||
owned_by = model.get("owned_by", "unknown owner")
|
||||
if "pipe" in model:
|
||||
pipe = model["pipe"]
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -35,12 +36,20 @@ from open_webui.config import (
|
|||
AppConfig,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
|
||||
from open_webui.env import (
|
||||
WEBUI_NAME,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.auth import get_password_hash, create_token
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
|
||||
|
||||
auth_manager_config = AppConfig()
|
||||
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
|
|
@ -61,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
|||
|
||||
|
||||
class OAuthManager:
|
||||
def __init__(self):
|
||||
def __init__(self, app):
|
||||
self.oauth = OAuth()
|
||||
self.app = app
|
||||
for _, provider_config in OAUTH_PROVIDERS.items():
|
||||
provider_config["register"](self.oauth)
|
||||
|
||||
|
|
@ -72,17 +82,21 @@ class OAuthManager:
|
|||
def get_user_role(self, user, user_data):
|
||||
if user and Users.get_num_users() == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
log.debug("Assigning the only user the admin role")
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
log.debug("Assigning the first user the admin role")
|
||||
return "admin"
|
||||
|
||||
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||
log.debug("Running OAUTH Role management")
|
||||
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
|
||||
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
|
||||
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
|
||||
oauth_roles = None
|
||||
role = "pending" # Default/fallback role if no matching roles are found
|
||||
# Default/fallback role if no matching roles are found
|
||||
role = auth_manager_config.DEFAULT_USER_ROLE
|
||||
|
||||
# Next block extracts the roles from the user data, accepting nested claims of any depth
|
||||
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
|
||||
|
|
@ -92,17 +106,24 @@ class OAuthManager:
|
|||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||
|
||||
log.debug(f"Oauth Roles claim: {oauth_claim}")
|
||||
log.debug(f"User roles from oauth: {oauth_roles}")
|
||||
log.debug(f"Accepted user roles: {oauth_allowed_roles}")
|
||||
log.debug(f"Accepted admin roles: {oauth_admin_roles}")
|
||||
|
||||
# If any roles are found, check if they match the allowed or admin roles
|
||||
if oauth_roles:
|
||||
# If role management is enabled, and matching roles are provided, use the roles
|
||||
for allowed_role in oauth_allowed_roles:
|
||||
# If the user has any of the allowed roles, assign the role "user"
|
||||
if allowed_role in oauth_roles:
|
||||
log.debug("Assigned user the user role")
|
||||
role = "user"
|
||||
break
|
||||
for admin_role in oauth_admin_roles:
|
||||
# If the user has any of the admin roles, assign the role "admin"
|
||||
if admin_role in oauth_roles:
|
||||
log.debug("Assigned user the admin role")
|
||||
role = "admin"
|
||||
break
|
||||
else:
|
||||
|
|
@ -116,16 +137,34 @@ class OAuthManager:
|
|||
return role
|
||||
|
||||
def update_user_groups(self, user, user_data, default_permissions):
|
||||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
|
||||
# Nested claim search for groups claim
|
||||
if oauth_claim:
|
||||
claim_data = user_data
|
||||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
user_oauth_groups = claim_data if isinstance(claim_data, list) else []
|
||||
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||
log.debug(f"User oauth groups: {user_oauth_groups}")
|
||||
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
|
||||
log.debug(
|
||||
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
|
||||
)
|
||||
|
||||
# Remove groups that user is no longer a part of
|
||||
for group_model in user_current_groups:
|
||||
if group_model.name not in user_oauth_groups:
|
||||
# Remove group from user
|
||||
log.debug(
|
||||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
|
|
@ -151,6 +190,9 @@ class OAuthManager:
|
|||
gm.name == group_model.name for gm in user_current_groups
|
||||
):
|
||||
# Add user to group
|
||||
log.debug(
|
||||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
|
|
@ -170,7 +212,7 @@ class OAuthManager:
|
|||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
async def handle_login(self, provider, request):
|
||||
async def handle_login(self, request, provider):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
||||
|
|
@ -182,7 +224,7 @@ class OAuthManager:
|
|||
raise HTTPException(404)
|
||||
return await client.authorize_redirect(request, redirect_uri)
|
||||
|
||||
async def handle_callback(self, provider, request, response):
|
||||
async def handle_callback(self, request, provider, response):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
client = self.get_client(provider)
|
||||
|
|
@ -192,7 +234,7 @@ class OAuthManager:
|
|||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token.get("userinfo")
|
||||
if not user_data:
|
||||
if not user_data or auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data:
|
||||
user_data: UserInfo = await client.userinfo(token=token)
|
||||
if not user_data:
|
||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||
|
|
@ -204,11 +246,46 @@ class OAuthManager:
|
|||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||
email = user_data.get(email_claim, "").lower()
|
||||
email = user_data.get(email_claim, "")
|
||||
# We currently mandate that email addresses are provided
|
||||
if not email:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
# If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
||||
if provider == "github":
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/user/emails", headers=headers
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
emails = await resp.json()
|
||||
# use the primary email as the user's email
|
||||
primary_email = next(
|
||||
(e["email"] for e in emails if e.get("primary")),
|
||||
None,
|
||||
)
|
||||
if primary_email:
|
||||
email = primary_email
|
||||
else:
|
||||
log.warning(
|
||||
"No primary email found in GitHub response"
|
||||
)
|
||||
raise HTTPException(
|
||||
400, detail=ERROR_MESSAGES.INVALID_CRED
|
||||
)
|
||||
else:
|
||||
log.warning("Failed to fetch GitHub email")
|
||||
raise HTTPException(
|
||||
400, detail=ERROR_MESSAGES.INVALID_CRED
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning(f"Error fetching GitHub email: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
else:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
email = email.lower()
|
||||
if (
|
||||
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||
and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||
|
|
@ -236,12 +313,12 @@ class OAuthManager:
|
|||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
|
||||
if not user:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(
|
||||
user_data.get("email", "").lower()
|
||||
)
|
||||
existing_user = Users.get_user_by_email(email)
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
|
|
@ -260,24 +337,35 @@ class OAuthManager:
|
|||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
if resp.ok:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(
|
||||
picture_url
|
||||
)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
else:
|
||||
picture_url = "/user.png"
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error downloading profile image '{picture_url}': {e}"
|
||||
)
|
||||
picture_url = ""
|
||||
picture_url = "/user.png"
|
||||
if not picture_url:
|
||||
picture_url = "/user.png"
|
||||
|
||||
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
||||
|
||||
name = user_data.get(username_claim)
|
||||
if not name:
|
||||
log.warning("Username claim is missing, using email as name")
|
||||
name = email
|
||||
|
||||
role = self.get_user_role(None, user_data)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
|
|
@ -285,7 +373,7 @@ class OAuthManager:
|
|||
password=get_password_hash(
|
||||
str(uuid.uuid4())
|
||||
), # Random password, not used
|
||||
name=user_data.get(username_claim, "User"),
|
||||
name=name,
|
||||
profile_image_url=picture_url,
|
||||
role=role,
|
||||
oauth_sub=provider_sub,
|
||||
|
|
@ -293,6 +381,7 @@ class OAuthManager:
|
|||
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
WEBUI_NAME,
|
||||
auth_manager_config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
|
|
@ -323,8 +412,8 @@ class OAuthManager:
|
|||
key="token",
|
||||
value=jwt_token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
|
|
@ -333,12 +422,9 @@ class OAuthManager:
|
|||
key="oauth_id_token",
|
||||
value=oauth_id_token,
|
||||
httponly=True,
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
|
||||
oauth_manager = OAuthManager()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,27 @@
|
|||
from open_webui.utils.task import prompt_template
|
||||
from open_webui.utils.task import prompt_template, prompt_variables_template
|
||||
from open_webui.utils.misc import (
|
||||
add_or_update_system_message,
|
||||
)
|
||||
|
||||
from typing import Callable, Optional
|
||||
import json
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
|
||||
def apply_model_system_prompt_to_body(
|
||||
params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
|
||||
) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
# Metadata (WebUI Usage)
|
||||
if metadata:
|
||||
variables = metadata.get("variables", {})
|
||||
if variables:
|
||||
system = prompt_variables_template(system, variables)
|
||||
|
||||
# Legacy (API Usage)
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
|
|
@ -19,7 +29,9 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
|
|||
}
|
||||
else:
|
||||
template_params = {}
|
||||
|
||||
system = prompt_template(system, **template_params)
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
)
|
||||
|
|
@ -50,43 +62,55 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
|||
"reasoning_effort": str,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
"logit_bias": lambda x: x,
|
||||
}
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
|
||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
opts = [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"seed",
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"num_ctx",
|
||||
"num_batch",
|
||||
"num_keep",
|
||||
"repeat_last_n",
|
||||
"tfs_z",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"use_mmap",
|
||||
"use_mlock",
|
||||
"num_thread",
|
||||
"num_gpu",
|
||||
]
|
||||
mappings = {i: lambda x: x for i in opts}
|
||||
form_data = apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
# Convert OpenAI parameter names to Ollama parameter names if needed.
|
||||
name_differences = {
|
||||
"max_tokens": "num_predict",
|
||||
"frequency_penalty": "repeat_penalty",
|
||||
}
|
||||
|
||||
for key, value in name_differences.items():
|
||||
if (param := params.get(key, None)) is not None:
|
||||
form_data[value] = param
|
||||
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
|
||||
params[value] = params[key]
|
||||
del params[key]
|
||||
|
||||
return form_data
|
||||
# See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": float,
|
||||
"seed": lambda x: x,
|
||||
"mirostat": int,
|
||||
"mirostat_eta": float,
|
||||
"mirostat_tau": float,
|
||||
"num_ctx": int,
|
||||
"num_batch": int,
|
||||
"num_keep": int,
|
||||
"num_predict": int,
|
||||
"repeat_last_n": int,
|
||||
"top_k": int,
|
||||
"min_p": float,
|
||||
"typical_p": float,
|
||||
"repeat_penalty": float,
|
||||
"presence_penalty": float,
|
||||
"frequency_penalty": float,
|
||||
"penalize_newline": bool,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
"numa": bool,
|
||||
"num_gpu": int,
|
||||
"main_gpu": int,
|
||||
"low_vram": bool,
|
||||
"vocab_only": bool,
|
||||
"use_mmap": bool,
|
||||
"use_mlock": bool,
|
||||
"num_thread": int,
|
||||
}
|
||||
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
|
||||
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
||||
|
|
@ -97,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
|||
new_message = {"role": message["role"]}
|
||||
|
||||
content = message.get("content", [])
|
||||
tool_calls = message.get("tool_calls", None)
|
||||
tool_call_id = message.get("tool_call_id", None)
|
||||
|
||||
# Check if the content is a string (just a simple message)
|
||||
if isinstance(content, str):
|
||||
if isinstance(content, str) and not tool_calls:
|
||||
# If the content is a string, it's pure text
|
||||
new_message["content"] = content
|
||||
|
||||
# If message is a tool call, add the tool call id to the message
|
||||
if tool_call_id:
|
||||
new_message["tool_call_id"] = tool_call_id
|
||||
|
||||
elif tool_calls:
|
||||
# If tool calls are present, add them to the message
|
||||
ollama_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
ollama_tool_call = {
|
||||
"index": tool_call.get("index", 0),
|
||||
"id": tool_call.get("id", None),
|
||||
"function": {
|
||||
"name": tool_call.get("function", {}).get("name", ""),
|
||||
"arguments": json.loads(
|
||||
tool_call.get("function", {}).get("arguments", {})
|
||||
),
|
||||
},
|
||||
}
|
||||
ollama_tool_calls.append(ollama_tool_call)
|
||||
new_message["tool_calls"] = ollama_tool_calls
|
||||
|
||||
# Put the content to empty string (Ollama requires an empty string for tool calls)
|
||||
new_message["content"] = ""
|
||||
|
||||
else:
|
||||
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
|
||||
content_text = ""
|
||||
|
|
@ -155,37 +206,38 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
|||
)
|
||||
ollama_payload["stream"] = openai_payload.get("stream", False)
|
||||
|
||||
if "tools" in openai_payload:
|
||||
ollama_payload["tools"] = openai_payload["tools"]
|
||||
|
||||
if "format" in openai_payload:
|
||||
ollama_payload["format"] = openai_payload["format"]
|
||||
|
||||
# If there are advanced parameters in the payload, format them in Ollama's options field
|
||||
ollama_options = {}
|
||||
|
||||
if openai_payload.get("options"):
|
||||
ollama_payload["options"] = openai_payload["options"]
|
||||
ollama_options = openai_payload["options"]
|
||||
|
||||
# Handle parameters which map directly
|
||||
for param in ["temperature", "top_p", "seed"]:
|
||||
if param in openai_payload:
|
||||
ollama_options[param] = openai_payload[param]
|
||||
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_tokens" in ollama_options:
|
||||
ollama_options["num_predict"] = ollama_options["max_tokens"]
|
||||
del ollama_options[
|
||||
"max_tokens"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
|
||||
# Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_completion_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
|
||||
elif "max_tokens" in openai_payload:
|
||||
ollama_options["num_predict"] = openai_payload["max_tokens"]
|
||||
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
|
||||
if "system" in ollama_options:
|
||||
ollama_payload["system"] = ollama_options["system"]
|
||||
del ollama_options[
|
||||
"system"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
|
||||
# Handle frequency / presence_penalty, which needs renaming and checking
|
||||
if "frequency_penalty" in openai_payload:
|
||||
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
|
||||
|
||||
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
|
||||
# We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
|
||||
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
|
||||
|
||||
# Add options to payload if any have been set
|
||||
if ollama_options:
|
||||
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
|
||||
if "stop" in openai_payload:
|
||||
ollama_options = ollama_payload.get("options", {})
|
||||
ollama_options["stop"] = openai_payload.get("stop")
|
||||
ollama_payload["options"] = ollama_options
|
||||
|
||||
if "metadata" in openai_payload:
|
||||
ollama_payload["metadata"] = openai_payload["metadata"]
|
||||
|
||||
return ollama_payload
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
from html import escape
|
||||
|
||||
from markdown import markdown
|
||||
|
||||
|
|
@ -41,13 +42,13 @@ class PDFGenerator:
|
|||
|
||||
def _build_html_message(self, message: Dict[str, Any]) -> str:
|
||||
"""Build HTML for a single message."""
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
role = escape(message.get("role", "user"))
|
||||
content = escape(message.get("content", ""))
|
||||
timestamp = message.get("timestamp")
|
||||
|
||||
model = message.get("model") if role == "assistant" else ""
|
||||
model = escape(message.get("model") if role == "assistant" else "")
|
||||
|
||||
date_str = self.format_timestamp(timestamp) if timestamp else ""
|
||||
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
|
||||
|
||||
# extends pymdownx extension to convert markdown to html.
|
||||
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
|
||||
|
|
@ -76,6 +77,7 @@ class PDFGenerator:
|
|||
|
||||
def _generate_html_body(self) -> str:
|
||||
"""Generate the full HTML body for the PDF."""
|
||||
escaped_title = escape(self.form_data.title)
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
|
|
@ -84,7 +86,7 @@ class PDFGenerator:
|
|||
<body>
|
||||
<div>
|
||||
<div>
|
||||
<h2>{self.form_data.title}</h2>
|
||||
<h2>{escaped_title}</h2>
|
||||
{self.messages_html}
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -108,7 +110,7 @@ class PDFGenerator:
|
|||
# When running using `pip install -e .` the static directory is in the site packages.
|
||||
# This path only works if `open-webui serve` is run from the root of this project.
|
||||
if not FONTS_DIR.exists():
|
||||
FONTS_DIR = Path("./backend/static/fonts")
|
||||
FONTS_DIR = Path(".") / "backend" / "static" / "fonts"
|
||||
|
||||
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
|
||||
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def extract_frontmatter(content):
|
|||
frontmatter[key.strip()] = value.strip()
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
log.exception(f"Failed to extract frontmatter: {e}")
|
||||
return {}
|
||||
|
||||
return frontmatter
|
||||
|
|
@ -167,9 +167,14 @@ def load_function_module_by_id(function_id, content=None):
|
|||
|
||||
def install_frontmatter_requirements(requirements):
|
||||
if requirements:
|
||||
req_list = [req.strip() for req in requirements.split(",")]
|
||||
for req in req_list:
|
||||
log.info(f"Installing requirement: {req}")
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
|
||||
try:
|
||||
req_list = [req.strip() for req in requirements.split(",")]
|
||||
for req in req_list:
|
||||
log.info(f"Installing requirement: {req}")
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
|
||||
except Exception as e:
|
||||
log.error(f"Error installing package: {req}")
|
||||
raise e
|
||||
|
||||
else:
|
||||
log.info("No requirements found in frontmatter.")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,101 @@
|
|||
import json
|
||||
from uuid import uuid4
|
||||
from open_webui.utils.misc import (
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
)
|
||||
|
||||
|
||||
def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
|
||||
openai_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
openai_tool_call = {
|
||||
"index": tool_call.get("index", 0),
|
||||
"id": tool_call.get("id", f"call_{str(uuid4())}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("function", {}).get("name", ""),
|
||||
"arguments": json.dumps(
|
||||
tool_call.get("function", {}).get("arguments", {})
|
||||
),
|
||||
},
|
||||
}
|
||||
openai_tool_calls.append(openai_tool_call)
|
||||
return openai_tool_calls
|
||||
|
||||
|
||||
def convert_ollama_usage_to_openai(data: dict) -> dict:
|
||||
return {
|
||||
"response_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("eval_count", 0)
|
||||
/ ((data.get("eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"prompt_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("prompt_eval_count", 0)
|
||||
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("prompt_eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"total_duration": data.get("total_duration", 0),
|
||||
"load_duration": data.get("load_duration", 0),
|
||||
"prompt_eval_count": data.get("prompt_eval_count", 0),
|
||||
"prompt_tokens": int(
|
||||
data.get("prompt_eval_count", 0)
|
||||
), # This is the OpenAI compatible key
|
||||
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
|
||||
"eval_count": data.get("eval_count", 0),
|
||||
"completion_tokens": int(
|
||||
data.get("eval_count", 0)
|
||||
), # This is the OpenAI compatible key
|
||||
"eval_duration": data.get("eval_duration", 0),
|
||||
"approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
|
||||
(data.get("total_duration", 0) or 0) // 1_000_000_000
|
||||
),
|
||||
"total_tokens": int( # This is the OpenAI compatible key
|
||||
data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
|
||||
),
|
||||
"completion_tokens_details": { # This is the OpenAI compatible key
|
||||
"reasoning_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
model = ollama_response.get("model", "ollama")
|
||||
message_content = ollama_response.get("message", {}).get("content", "")
|
||||
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
response = openai_chat_completion_message_template(model, message_content)
|
||||
if tool_calls:
|
||||
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
|
||||
|
||||
data = ollama_response
|
||||
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
response = openai_chat_completion_message_template(
|
||||
model, message_content, openai_tool_calls, usage
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
|
|
@ -18,53 +104,21 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
|||
data = json.loads(data)
|
||||
|
||||
model = data.get("model", "ollama")
|
||||
message_content = data.get("message", {}).get("content", "")
|
||||
message_content = data.get("message", {}).get("content", None)
|
||||
tool_calls = data.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
if tool_calls:
|
||||
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
|
||||
|
||||
done = data.get("done", False)
|
||||
|
||||
usage = None
|
||||
if done:
|
||||
usage = {
|
||||
"response_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("eval_count", 0)
|
||||
/ ((data.get("eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"prompt_token/s": (
|
||||
round(
|
||||
(
|
||||
(
|
||||
data.get("prompt_eval_count", 0)
|
||||
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
|
||||
)
|
||||
* 100
|
||||
),
|
||||
2,
|
||||
)
|
||||
if data.get("prompt_eval_duration", 0) > 0
|
||||
else "N/A"
|
||||
),
|
||||
"total_duration": data.get("total_duration", 0),
|
||||
"load_duration": data.get("load_duration", 0),
|
||||
"prompt_eval_count": data.get("prompt_eval_count", 0),
|
||||
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
|
||||
"eval_count": data.get("eval_count", 0),
|
||||
"eval_duration": data.get("eval_duration", 0),
|
||||
"approximate_total": (
|
||||
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
|
||||
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
|
||||
}
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
data = openai_chat_chunk_message_template(
|
||||
model, message_content if not done else None, usage
|
||||
model, message_content, openai_tool_calls, usage
|
||||
)
|
||||
|
||||
line = f"data: {json.dumps(data)}\n\n"
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def get_task_model_id(
|
|||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id]["owned_by"] == "ollama":
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
|
|
@ -32,6 +32,12 @@ def get_task_model_id(
|
|||
return task_model_id
|
||||
|
||||
|
||||
def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
|
||||
for variable, value in variables.items():
|
||||
template = template.replace(variable, value)
|
||||
return template
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||
) -> str:
|
||||
|
|
@ -98,7 +104,7 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
|
|||
|
||||
|
||||
def replace_messages_variable(
|
||||
template: str, messages: Optional[list[str]] = None
|
||||
template: str, messages: Optional[list[dict]] = None
|
||||
) -> str:
|
||||
def replacement_function(match):
|
||||
full_match = match.group(0)
|
||||
|
|
|
|||
|
|
@ -61,6 +61,12 @@ def get_tools(
|
|||
)
|
||||
|
||||
for spec in tools.specs:
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal parameters
|
||||
spec["parameters"]["properties"] = {
|
||||
key: val
|
||||
|
|
@ -73,6 +79,13 @@ def get_tools(
|
|||
# convert to function that takes only model params and inserts custom params
|
||||
original_func = getattr(module, function_name)
|
||||
callable = apply_extra_params_to_tool_function(original_func, extra_params)
|
||||
|
||||
if callable.__doc__ and callable.__doc__.strip() != "":
|
||||
s = re.split(":(param|return)", callable.__doc__, 1)
|
||||
spec["description"] = s[0]
|
||||
else:
|
||||
spec["description"] = function_name
|
||||
|
||||
# TODO: This needs to be a pydantic model
|
||||
tool_dict = {
|
||||
"toolkit_id": tool_id,
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@ import json
|
|||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
|
||||
from open_webui.config import WEBUI_FAVICON_URL
|
||||
from open_webui.env import SRC_LOG_LEVELS, VERSION
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||
|
||||
|
||||
def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
|
||||
try:
|
||||
log.debug(f"post_webhook: {url}, {message}, {event_data}")
|
||||
payload = {}
|
||||
|
|
@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
|||
"sections": [
|
||||
{
|
||||
"activityTitle": message,
|
||||
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
|
||||
"activitySubtitle": f"{name} ({VERSION}) - {action}",
|
||||
"activityImage": WEBUI_FAVICON_URL,
|
||||
"facts": facts,
|
||||
"markdown": True,
|
||||
|
|
|
|||
|
|
@ -1,29 +1,26 @@
|
|||
fastapi==0.111.0
|
||||
uvicorn[standard]==0.30.6
|
||||
pydantic==2.9.2
|
||||
fastapi==0.115.7
|
||||
uvicorn[standard]==0.34.0
|
||||
pydantic==2.10.6
|
||||
python-multipart==0.0.18
|
||||
|
||||
Flask==3.1.0
|
||||
Flask-Cors==5.0.0
|
||||
|
||||
python-socketio==5.11.3
|
||||
python-jose==3.3.0
|
||||
python-jose==3.4.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
|
||||
requests==2.32.3
|
||||
aiohttp==3.11.8
|
||||
aiohttp==3.11.11
|
||||
async-timeout
|
||||
aiocache
|
||||
aiofiles
|
||||
|
||||
sqlalchemy==2.0.32
|
||||
sqlalchemy==2.0.38
|
||||
alembic==1.14.0
|
||||
peewee==3.17.8
|
||||
peewee==3.17.9
|
||||
peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
pgvector==0.3.5
|
||||
PyMySQL==1.1.1
|
||||
bcrypt==4.2.0
|
||||
bcrypt==4.3.0
|
||||
|
||||
pymongo
|
||||
redis
|
||||
|
|
@ -32,20 +29,27 @@ boto3==1.35.53
|
|||
argon2-cffi==23.1.0
|
||||
APScheduler==3.10.4
|
||||
|
||||
RestrictedPython==8.0
|
||||
|
||||
loguru==0.7.2
|
||||
asgiref==3.8.1
|
||||
|
||||
# AI libraries
|
||||
openai
|
||||
anthropic
|
||||
google-generativeai==0.7.2
|
||||
tiktoken
|
||||
|
||||
langchain==0.3.7
|
||||
langchain-community==0.3.7
|
||||
langchain==0.3.19
|
||||
langchain-community==0.3.18
|
||||
|
||||
fake-useragent==1.5.1
|
||||
chromadb==0.6.2
|
||||
pymilvus==2.5.0
|
||||
qdrant-client~=1.12.0
|
||||
opensearch-py==2.7.1
|
||||
opensearch-py==2.8.0
|
||||
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
|
||||
elasticsearch==8.17.1
|
||||
|
||||
|
||||
transformers
|
||||
|
|
@ -57,10 +61,10 @@ einops==0.8.0
|
|||
ftfy==6.2.3
|
||||
pypdf==4.3.1
|
||||
fpdf2==2.8.2
|
||||
pymdown-extensions==10.11.2
|
||||
pymdown-extensions==10.14.2
|
||||
docx2txt==0.8
|
||||
python-pptx==1.0.0
|
||||
unstructured==0.15.9
|
||||
unstructured==0.16.17
|
||||
nltk==3.9.1
|
||||
Markdown==3.7
|
||||
pypandoc==1.13
|
||||
|
|
@ -71,25 +75,26 @@ xlrd==2.0.1
|
|||
validators==0.34.0
|
||||
psutil
|
||||
sentencepiece
|
||||
soundfile==0.12.1
|
||||
soundfile==0.13.1
|
||||
azure-ai-documentintelligence==1.0.0
|
||||
|
||||
opencv-python-headless==4.10.0.84
|
||||
opencv-python-headless==4.11.0.86
|
||||
rapidocr-onnxruntime==1.3.24
|
||||
rank-bm25==0.2.2
|
||||
|
||||
faster-whisper==1.0.3
|
||||
faster-whisper==1.1.1
|
||||
|
||||
PyJWT[crypto]==2.10.1
|
||||
authlib==1.3.2
|
||||
authlib==1.4.1
|
||||
|
||||
black==24.8.0
|
||||
black==25.1.0
|
||||
langfuse==2.44.0
|
||||
youtube-transcript-api==0.6.3
|
||||
pytube==15.0.0
|
||||
|
||||
extract_msg
|
||||
pydub
|
||||
duckduckgo-search~=7.2.1
|
||||
duckduckgo-search~=7.3.2
|
||||
|
||||
## Google Drive
|
||||
google-api-python-client
|
||||
|
|
@ -104,5 +109,12 @@ pytest-docker~=3.1.1
|
|||
googleapis-common-protos==1.63.2
|
||||
google-cloud-storage==2.19.0
|
||||
|
||||
azure-identity==1.20.0
|
||||
azure-storage-blob==12.24.1
|
||||
|
||||
|
||||
## LDAP
|
||||
ldap3==2.9.1
|
||||
|
||||
## Firecrawl
|
||||
firecrawl-py==1.12.0
|
||||
|
|
|
|||
|
|
@ -3,6 +3,17 @@
|
|||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd "$SCRIPT_DIR" || exit
|
||||
|
||||
# Add conditional Playwright browser installation
|
||||
if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
|
||||
if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then
|
||||
echo "Installing Playwright browsers..."
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
fi
|
||||
|
||||
python -c "import nltk; nltk.download('punkt_tab')"
|
||||
fi
|
||||
|
||||
KEY_FILE=.webui_secret_key
|
||||
|
||||
PORT="${PORT:-8080}"
|
||||
|
|
|
|||
|
|
@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION
|
|||
SET "SCRIPT_DIR=%~dp0"
|
||||
cd /d "%SCRIPT_DIR%" || exit /b
|
||||
|
||||
:: Add conditional Playwright browser installation
|
||||
IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" (
|
||||
IF "%PLAYWRIGHT_WS_URI%" == "" (
|
||||
echo Installing Playwright browsers...
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
)
|
||||
|
||||
python -c "import nltk; nltk.download('punkt_tab')"
|
||||
)
|
||||
|
||||
SET "KEY_FILE=.webui_secret_key"
|
||||
IF "%PORT%"=="" SET PORT=8080
|
||||
IF "%HOST%"=="" SET HOST=0.0.0.0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
services:
|
||||
playwright:
|
||||
image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt
|
||||
container_name: playwright
|
||||
command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0
|
||||
|
||||
open-webui:
|
||||
environment:
|
||||
- 'RAG_WEB_LOADER_ENGINE=playwright'
|
||||
- 'PLAYWRIGHT_WS_URI=ws://playwright:3000'
|
||||