1302 lines
		
	
	
		
			45 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			1302 lines
		
	
	
		
			45 KiB
		
	
	
	
		
			Python
		
	
	
	
| import time
 | |
| import logging
 | |
| import sys
 | |
| 
 | |
| import asyncio
 | |
| from aiocache import cached
 | |
| from typing import Any, Optional
 | |
| import random
 | |
| import json
 | |
| import inspect
 | |
| from uuid import uuid4
 | |
| from concurrent.futures import ThreadPoolExecutor
 | |
| 
 | |
| 
 | |
| from fastapi import Request
 | |
| from fastapi import BackgroundTasks
 | |
| 
 | |
| from starlette.responses import Response, StreamingResponse
 | |
| 
 | |
| 
 | |
| from open_webui.models.chats import Chats
 | |
| from open_webui.models.users import Users
 | |
| from open_webui.socket.main import (
 | |
|     get_event_call,
 | |
|     get_event_emitter,
 | |
|     get_active_status_by_user_id,
 | |
| )
 | |
| from open_webui.routers.tasks import (
 | |
|     generate_queries,
 | |
|     generate_title,
 | |
|     generate_image_prompt,
 | |
|     generate_chat_tags,
 | |
| )
 | |
| from open_webui.routers.retrieval import process_web_search, SearchForm
 | |
| from open_webui.routers.images import image_generations, GenerateImageForm
 | |
| 
 | |
| 
 | |
| from open_webui.utils.webhook import post_webhook
 | |
| 
 | |
| 
 | |
| from open_webui.models.users import UserModel
 | |
| from open_webui.models.functions import Functions
 | |
| from open_webui.models.models import Models
 | |
| 
 | |
| from open_webui.retrieval.utils import get_sources_from_files
 | |
| 
 | |
| 
 | |
| from open_webui.utils.chat import generate_chat_completion
 | |
| from open_webui.utils.task import (
 | |
|     get_task_model_id,
 | |
|     rag_template,
 | |
|     tools_function_calling_generation_template,
 | |
| )
 | |
| from open_webui.utils.misc import (
 | |
|     get_message_list,
 | |
|     add_or_update_system_message,
 | |
|     get_last_user_message,
 | |
|     get_last_assistant_message,
 | |
|     prepend_to_first_user_message_content,
 | |
| )
 | |
| from open_webui.utils.tools import get_tools
 | |
| from open_webui.utils.plugin import load_function_module_by_id
 | |
| 
 | |
| 
 | |
| from open_webui.tasks import create_task
 | |
| 
 | |
| from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 | |
| from open_webui.env import (
 | |
|     SRC_LOG_LEVELS,
 | |
|     GLOBAL_LOG_LEVEL,
 | |
|     BYPASS_MODEL_ACCESS_CONTROL,
 | |
|     ENABLE_REALTIME_CHAT_SAVE,
 | |
| )
 | |
| from open_webui.constants import TASKS
 | |
| 
 | |
| 
 | |
# Configure logging to write to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
 | |
| 
 | |
| 
 | |
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
    """Pass ``body`` through the ``inlet`` hook of every applicable filter function.

    The applicable filters are the global filter functions plus the ids listed
    under ``model["info"]["meta"]["filterIds"]``, restricted to filters that are
    currently active and ordered by their ``priority`` valve (ascending,
    default 0).

    Args:
        request: Incoming request; ``request.app.state.FUNCTIONS`` is used as a
            cache of loaded function modules.
        body: Chat completion payload handed to each ``inlet``.
        model: Model dict the request is addressed to.
        extra_params: Optional keyword arguments (e.g. ``__user__``) that are
            forwarded to an ``inlet`` only if its signature names them.

    Returns:
        ``(body, {})`` where ``body`` is the payload returned by the last
        ``inlet``. ``body["metadata"]["files"]`` is removed when a filter
        module sets a truthy ``file_handler`` attribute.

    Raises:
        Exception: Any error raised by an ``inlet`` is logged and re-raised.
    """
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
            # De-duplicate while keeping first-seen order, so filters that share
            # a priority run in a deterministic order (sort below is stable).
            filter_ids = list(dict.fromkeys(filter_ids))

        enabled_filter_ids = {
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        }
        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        filter_ids.sort(key=get_priority)
        return filter_ids

    for filter_id in get_filter_function_ids(model):
        filter_function = Functions.get_function_by_id(filter_id)
        if not filter_function:
            continue

        # Reuse the cached module when available, otherwise load and cache it.
        if filter_id in request.app.state.FUNCTIONS:
            function_module = request.app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            request.app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet
            inlet_parameters = inspect.signature(inlet).parameters

            # Only pass the optional parameters the inlet actually declares.
            optional_params = {
                **extra_params,
                "__model__": model,
                "__id__": filter_id,
            }
            params = {"body": body} | {
                key: value
                for key, value in optional_params.items()
                if key in inlet_parameters
            }

            if "__user__" in params and hasattr(function_module, "UserValves"):
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception as e:
                    # Best effort: the inlet still runs without user valves.
                    log.exception(e)

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)

        except Exception as e:
            log.exception(f"Error: {e}")
            raise

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}
 | |
| 
 | |
| 
 | |
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, extra_params: dict
) -> tuple[dict, dict]:
    """Ask the task model to pick a tool, run it, and collect its output as a source.

    Args:
        request: Incoming request; configuration is read from
            ``request.app.state.config``.
        body: Chat completion payload. ``body["metadata"]["tool_ids"]`` selects
            the tools; nothing happens when it is missing or empty.
        user: User on whose behalf the tools and the task model are invoked.
        models: Mapping of model id to model dict.
        extra_params: Extra keyword arguments made available to the tools.

    Returns:
        ``(body, flags)``. ``flags`` is ``{}`` when no tool was selected and
        ``{"sources": [...]}`` otherwise. ``body["metadata"]["files"]`` is
        removed when the executed tool declares a truthy ``file_handler``.
    """

    async def get_content_from_response(response) -> Optional[str]:
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        user_message = get_last_user_message(messages)
        # The four most recent messages, newest first.
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    # If tool_ids field is present, call the functions
    metadata = body.get("metadata", {})

    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        return body, {}

    skip_files = False
    sources = []

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    tools = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models[task_model_id],
            "__messages__": body["messages"],
            "__files__": metadata.get("files", []),
        },
    )
    log.info(f"{tools=}")

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            # Keep only the outermost {...} span of the model's answer.
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            tool_function_name = result.get("name", None)
            if tool_function_name not in tools:
                return body, {}

            tool = tools[tool_function_name]
            tool_function_params = result.get("parameters", {})

            try:
                parameters_spec = tool.get("spec", {}).get("parameters", {})
                # Accept every declared parameter, not only the required ones;
                # filtering on "required" alone silently dropped optional
                # arguments. Assumes the spec lists declared parameters under
                # "properties" (JSON-Schema layout); if that key is absent this
                # degrades to the required-only behaviour.
                allowed_params = set(parameters_spec.get("properties", {})) | set(
                    parameters_spec.get("required", [])
                )
                tool_function = tool["callable"]
                tool_function_params = {
                    key: value
                    for key, value in tool_function_params.items()
                    if key in allowed_params
                }
                tool_output = await tool_function(**tool_function_params)

            except Exception as e:
                # The error text is surfaced to the model as the tool output.
                tool_output = str(e)

            if isinstance(tool_output, str):
                # Only citation-enabled tools get a named (citable) source.
                show_citation = tool["citation"]
                tool_source = f"TOOL:{tool['toolkit_id']}/{tool_function_name}"
                sources.append(
                    {
                        "source": {"name": tool_source} if show_citation else {},
                        "document": [tool_output],
                        "metadata": [{"source": tool_source}],
                    }
                )

                if tool["file_handler"]:
                    skip_files = True

        except Exception as e:
            log.exception(f"Error: {e}")
    except Exception as e:
        log.exception(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
 | |
| 
 | |
| 
 | |
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate a search query, run a web search and attach the results as a file.

    Progress is reported through ``extra_params["__event_emitter__"]`` as
    ``status`` events with ``action == "web_search"``.

    Args:
        request: Incoming request, forwarded to the query generator and search.
        form_data: Chat payload; ``messages`` and ``model`` are read and a
            ``web_search_results`` entry is appended to ``files`` on success.
        extra_params: Must contain ``__event_emitter__``; also forwarded to
            ``process_web_search``.
        user: User on whose behalf the search runs.

    Returns:
        ``form_data`` in every case, including when no query could be
        generated or the search failed, so callers can always keep using the
        returned payload.
    """
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Generating search query",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        raw_response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = raw_response.find("{")
            # rfind() yields -1 on a miss, so a missing "}" shows up as 0 here.
            bracket_end = raw_response.rfind("}") + 1

            if bracket_start == -1 or bracket_end == 0:
                raise Exception("No JSON object found in the response")

            queries = json.loads(raw_response[bracket_start:bracket_end])
            queries = queries.get("queries", [])
        except Exception:
            # Fall back to the model's untouched text rather than a sliced
            # (possibly empty) fragment of it.
            queries = [raw_response]

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        # Return the payload unchanged; a bare return would hand None back.
        return form_data

    search_query = queries[0]

    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": 'Searching "{{searchQuery}}"',
                "query": search_query,
                "done": False,
            },
        }
    )

    try:
        results = await process_web_search(
            request,
            SearchForm(
                **{
                    "query": search_query,
                }
            ),
            extra_params=extra_params,
            user=user,
        )

        if results:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "Searched {{count}} sites",
                        "query": search_query,
                        "urls": results["filenames"],
                        "done": True,
                    },
                }
            )

            files = form_data.get("files", [])
            files.append(
                {
                    "collection_name": results["collection_name"],
                    "name": search_query,
                    "type": "web_search_results",
                    "urls": results["filenames"],
                }
            )
            form_data["files"] = files
        else:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "No search results found",
                        "query": search_query,
                        "done": True,
                        "error": True,
                    },
                }
            )
    except Exception as e:
        log.exception(e)
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": 'Error searching "{{searchQuery}}"',
                    "query": search_query,
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data
 | |
| 
 | |
| 
 | |
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the last user message and tell the model the outcome.

    Progress and results are reported through
    ``extra_params["__event_emitter__"]``.

    Args:
        request: Incoming request; ``ENABLE_IMAGE_PROMPT_GENERATION`` is read
            from ``request.app.state.config``.
        form_data: Chat payload; ``messages`` and ``model`` are read, and
            ``messages`` gains/updates a system message describing the outcome.
        extra_params: Must contain ``__event_emitter__``.
        user: User on whose behalf the image is generated.

    Returns:
        The (mutated) ``form_data``.
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    # Default to the user's own words; optionally replaced by a generated prompt.
    prompt = user_message

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            raw_response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = raw_response.find("{")
                # rfind() yields -1 on a miss, so a missing "}" shows up as 0.
                bracket_end = raw_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                parsed = json.loads(raw_response[bracket_start:bracket_end])
                # A missing or empty "prompt" falls back to the user's message
                # instead of an empty list.
                prompt = parsed.get("prompt") or user_message
            except Exception:
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        # NOTE(review): one message event is emitted per image, but its content
        # is only a newline and does not reference the image — the payload looks
        # truncated; confirm against the intended message format.
        for _ in images:
            await __event_emitter__(
                {
                    "type": "message",
                    "data": {"content": "\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": "An error occured while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
 | |
| 
 | |
| 
 | |
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve context sources for the files attached to a chat request.

    Nothing is retrieved unless ``body["metadata"]["files"]`` is non-empty.
    Retrieval queries are generated from the conversation; if that fails or
    yields nothing, the last user message is used as the only query.

    Args:
        request: Incoming request; retrieval settings are read from
            ``request.app.state``.
        body: Chat completion payload (``model``, ``messages``, ``metadata``).
        user: User on whose behalf queries are generated.

    Returns:
        ``(body, {"sources": sources})``; ``sources`` is empty when there are
        no files or retrieval failed.
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            raw_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = raw_response.find("{")
                # rfind() yields -1 on a miss, so a missing "}" shows up as 0.
                bracket_end = raw_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                parsed_response = json.loads(raw_response[bracket_start:bracket_end])
            except Exception:
                # Use the model's untouched text as the single query rather
                # than a sliced (possibly empty) fragment of it.
                parsed_response = {"queries": [raw_response]}

            queries = parsed_response.get("queries", [])
        except Exception as e:
            # Best effort: fall back to the last user message below.
            log.debug(f"Query generation for retrieval failed: {e}")
            queries = []

        if not queries:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_files to a separate thread
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        files=files,
                        queries=queries,
                        embedding_function=request.app.state.EMBEDDING_FUNCTION,
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                    ),
                )

        except Exception as e:
            log.exception(e)

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
 | |
| 
 | |
| 
 | |
def apply_params_to_form_data(form_data, model):
    """Move the ``params`` entry of ``form_data`` to where the backend expects it.

    ``form_data["params"]`` is always removed. For a model with a truthy
    ``"ollama"`` entry the whole params dict is stored under ``"options"`` and
    ``format`` / ``keep_alive`` are additionally copied to the top level. For
    any other model only a fixed set of known parameters is copied to the top
    level; everything else is dropped.

    Args:
        form_data: Chat payload dict; mutated in place.
        model: Model dict; only ``model.get("ollama")`` is inspected.

    Returns:
        The same ``form_data`` object.
    """
    # An explicit ``"params": None`` is treated like a missing entry instead
    # of failing on the membership tests below.
    params = form_data.pop("params", None) or {}

    if model.get("ollama"):
        form_data["options"] = params
        top_level_keys = ("format", "keep_alive")
    else:
        top_level_keys = (
            "seed",
            "stop",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "reasoning_effort",
        )

    for key in top_level_keys:
        if key in params:
            form_data[key] = params[key]

    return form_data
 | |
| 
 | |
| 
 | |
def _build_sources_context(sources):
    """Render every document in ``sources`` as a ``<source>`` snippet string."""
    snippets = []
    for source in sources:
        source_id = source.get("source", {}).get("name", "")
        if "document" not in source:
            continue

        source_metadata = source.get("metadata")
        for doc_idx, doc_context in enumerate(source["document"]):
            doc_source_id = None
            # Tolerate a metadata list that is shorter than the document list.
            if source_metadata and doc_idx < len(source_metadata):
                doc_source_id = source_metadata[doc_idx].get("source", source_id)

            if source_id:
                snippets.append(
                    f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                )
            else:
                # If there is no source_id, then do not include the source_id tag
                snippets.append(
                    f"<source><source_context>{doc_context}</source_context></source>\n"
                )

    return "".join(snippets).strip()


async def process_chat_payload(request, form_data, metadata, user, model):
    """Prepare a chat completion payload before it is sent to the model.

    Steps, in order: apply model params, attach the model's knowledge files,
    run the optional web-search and image-generation features, run filter
    ``inlet`` hooks, run tool calling and file retrieval, and finally inject
    the collected sources into the messages via the RAG template.

    Args:
        request: Incoming request; models and config are read from
            ``request.app.state``.
        form_data: Chat payload dict; mutated and returned.
        metadata: Request metadata; merged with ``tool_ids`` and ``files`` and
            stored under ``form_data["metadata"]``.
        user: Requesting user (``id``, ``email``, ``name``, ``role`` are read).
        model: Model dict the request is addressed to.

    Returns:
        ``(form_data, events)`` where ``events`` holds at most one
        ``{"sources": [...]}`` entry listing the named sources.

    Raises:
        Exception: When a filter function fails, or when sources were found
            but the conversation has no user message.
    """
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }

    # Initialize events to store additional event to be sent to the client
    # Initialize contexts and citation
    models = request.app.state.MODELS

    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    # "variables" is not forwarded to the model; drop it from the payload.
    form_data.pop("variables", None)

    features = form_data.pop("features", None)
    if features:
        if "web_search" in features and features["web_search"]:
            handled = await chat_web_search_handler(
                request, form_data, extra_params, user
            )
            # Never let a handler that returned nothing clobber the payload.
            if handled is not None:
                form_data = handled

        if "image_generation" in features and features["image_generation"]:
            handled = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )
            if handled is not None:
                form_data = handled

    try:
        form_data, _ = await chat_completion_filter_functions_handler(
            request, form_data, model, extra_params
        )
    except Exception as e:
        raise Exception(f"Error: {e}") from e

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)
    # Remove files duplicates
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    try:
        form_data, flags = await chat_completion_tools_handler(
            request, form_data, user, models, extra_params
        )
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # If context is not empty, insert it into the messages
    if len(sources) > 0:
        context_string = _build_sources_context(sources)
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if request.app.state.config.RELEVANCE_THRESHOLD == 0 and context_string == "":
            log.debug(
                "With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        rag_content = rag_template(
            request.app.state.config.RAG_TEMPLATE, context_string, prompt
        )

        # Workaround for Ollama 2.0+ system prompt issue
        # TODO: replace with add_or_update_system_message
        if model["owned_by"] == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_content,
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_content,
                form_data["messages"],
            )

    # If there are citations, add them to the data_items
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, events
 | |
| 
 | |
| 
 | |
def _strip_reasoning_tags(raw, tag):
    """Return *raw* with the ``<tag>\\n`` / ``</tag>\\n`` markers removed.

    ``str.strip(chars)`` treats its argument as a *set of characters*, so using
    it to drop the markers also eats legitimate leading/trailing letters of the
    reasoning text (e.g. a thought ending in "I think"). The markers are
    therefore removed as exact substrings instead.
    """
    start_tag = f"<{tag}>\n"
    end_tag = f"</{tag}>\n"

    text = raw.replace(start_tag, "").replace(end_tag, "")

    # The chunk that completed the start tag may have contributed only the tail
    # of the tag (e.g. ">\n" when the tag was split across chunks); drop the
    # longest such tail so it does not show up as reasoning text.
    for offset in range(1, len(start_tag)):
        fragment = start_tag[offset:]
        if text.startswith(fragment):
            text = text[len(fragment) :]
            break

    return text.strip()


def _quote_reasoning(text):
    """Render *text* as a Markdown blockquote, leaving already-quoted lines as is."""
    return "\n".join(
        line if line.startswith(">") else f"> {line}" for line in text.splitlines()
    )


async def process_chat_response(
    request, response, form_data, user, events, metadata, tasks
):
    """Post-process a chat completion response.

    Args:
        request: Incoming request; ``request.app.state.config.WEBUI_URL`` is
            read when building webhook notifications.
        response: Either a plain (dict-like) completion payload or a
            ``StreamingResponse``.
        form_data: Not read by this function.
        user: Requesting user; ``user.id`` is used for the activity lookup and
            is passed on to the title/tag generation helpers.
        events: Events emitted (or prepended to the stream) before the
            completion data.
        metadata: Mapping expected to provide ``session_id``, ``chat_id`` and
            ``message_id``; an event emitter is only created when all three
            are present and truthy.
        tasks: Optional mapping of ``TASKS`` members to flags selecting the
            follow-up title/tag generation.

    Returns:
        * the unchanged ``response`` for non-streaming payloads and for streams
          whose content type is neither ``text/event-stream`` nor
          ``application/x-ndjson``;
        * ``{"status": True, "task_id": ...}`` when the stream is consumed by a
          background task that forwards it through the event emitter;
        * a new ``StreamingResponse`` with ``events`` prepended when no event
          emitter is available.
    """
    event_emitter = None
    if all(metadata.get(key) for key in ("session_id", "chat_id", "message_id")):
        event_emitter = get_event_emitter(metadata)

    def save_message(fields):
        # Persist the given fields on the message this response belongs to.
        Chats.upsert_message_to_chat_by_id_and_message_id(
            metadata["chat_id"],
            metadata["message_id"],
            fields,
        )

    def notify_inactive_user(title, content):
        # Send a webhook notification if the user is not active.
        if get_active_status_by_user_id(user.id) is None:
            webhook_url = Users.get_user_webhook_url_by_id(user.id)
            if webhook_url:
                chat_url = (
                    f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}"
                )
                post_webhook(
                    webhook_url,
                    f"{title} - {chat_url}\n\n{content}",
                    {
                        "action": "chat",
                        "message": content,
                        "title": title,
                        "url": chat_url,
                    },
                )

    async def background_tasks_handler():
        # Runs the optional title / tag generation once the message is stored.
        # Only invoked on paths where ``event_emitter`` is set.
        message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
        message = message_map.get(metadata["message_id"]) if message_map else None

        if not message:
            return

        messages = get_message_list(message_map, message.get("id"))

        if not tasks:
            return

        if TASKS.TITLE_GENERATION in tasks:
            if tasks[TASKS.TITLE_GENERATION]:
                res = await generate_title(
                    request,
                    {
                        "model": message["model"],
                        "messages": messages,
                        "chat_id": metadata["chat_id"],
                    },
                    user,
                )

                if res and isinstance(res, dict):
                    if len(res.get("choices", [])) == 1:
                        title_string = (
                            res.get("choices", [])[0]
                            .get("message", {})
                            .get("content", message.get("content", "New Chat"))
                        )
                    else:
                        title_string = ""

                    # Keep only the outermost {...} span of the model output.
                    title_string = title_string[
                        title_string.find("{") : title_string.rfind("}") + 1
                    ]

                    try:
                        title = json.loads(title_string).get("title", "New Chat")
                    except Exception:
                        title = ""

                    if not title:
                        title = messages[0].get("content", "New Chat")

                    Chats.update_chat_title_by_id(metadata["chat_id"], title)

                    await event_emitter(
                        {
                            "type": "chat:title",
                            "data": title,
                        }
                    )
            elif len(messages) == 2:
                title = messages[0].get("content", "New Chat")

                Chats.update_chat_title_by_id(metadata["chat_id"], title)

                # Emit the same title that was just persisted (previously the
                # content of the current message was sent instead).
                await event_emitter(
                    {
                        "type": "chat:title",
                        "data": title,
                    }
                )

        if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
            res = await generate_chat_tags(
                request,
                {
                    "model": message["model"],
                    "messages": messages,
                    "chat_id": metadata["chat_id"],
                },
                user,
            )

            if res and isinstance(res, dict):
                if len(res.get("choices", [])) == 1:
                    tags_string = (
                        res.get("choices", [])[0]
                        .get("message", {})
                        .get("content", "")
                    )
                else:
                    tags_string = ""

                tags_string = tags_string[
                    tags_string.find("{") : tags_string.rfind("}") + 1
                ]

                try:
                    tags = json.loads(tags_string).get("tags", [])
                    Chats.update_chat_tags_by_id(metadata["chat_id"], tags, user)

                    await event_emitter(
                        {
                            "type": "chat:tags",
                            "data": tags,
                        }
                    )
                except Exception as e:
                    # Tag generation is best-effort; record the failure
                    # instead of swallowing it silently.
                    log.debug(f"Skipping chat tags update: {e}")

    # ------------------------------------------------------------------
    # Non-streaming response
    # ------------------------------------------------------------------
    if not isinstance(response, StreamingResponse):
        if event_emitter:
            if "selected_model_id" in response:
                save_message({"selectedModelId": response["selected_model_id"]})

            # A payload without choices (e.g. an error body) carries no content.
            choices = response.get("choices") or []
            content = (
                choices[0].get("message", {}).get("content") if choices else None
            )

            if content:
                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": response,
                    }
                )

                title = Chats.get_chat_title_by_id(metadata["chat_id"])

                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": {
                            "done": True,
                            "content": content,
                            "title": title,
                        },
                    }
                )

                # Save message in the database
                save_message({"content": content})

                notify_inactive_user(title, content)

                await background_tasks_handler()

        return response

    # ------------------------------------------------------------------
    # Streaming response
    # ------------------------------------------------------------------
    content_type_header = response.headers.get("Content-Type", "")
    if not any(
        content_type in content_type_header
        for content_type in ["text/event-stream", "application/x-ndjson"]
    ):
        return response

    if event_emitter:

        # Handle as a background task
        async def post_response_handler(response, events):
            message = Chats.get_message_by_id_and_message_id(
                metadata["chat_id"], metadata["message_id"]
            )
            content = message.get("content", "") if message else ""

            try:
                for event in events:
                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": event,
                        }
                    )

                    # Save message in the database
                    save_message({**event})

                # We might want to disable this by default
                detect_reasoning = True
                reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
                current_tag = None

                reasoning_start_time = None

                reasoning_content = ""
                ongoing_content = ""

                async for raw_line in response.body_iterator:
                    line = (
                        raw_line.decode("utf-8")
                        if isinstance(raw_line, bytes)
                        else raw_line
                    )

                    # Skip empty lines and anything that is not a "data:" event
                    if not line.strip() or not line.startswith("data:"):
                        continue

                    # Remove the prefix
                    payload = line[len("data:") :].strip()

                    try:
                        data = json.loads(payload)

                        if "selected_model_id" in data:
                            save_message(
                                {"selectedModelId": data["selected_model_id"]}
                            )
                        else:
                            value = (
                                data.get("choices", [])[0]
                                .get("delta", {})
                                .get("content")
                            )

                            if value:
                                content = f"{content}{value}"

                                if detect_reasoning:
                                    for tag in reasoning_tags:
                                        start_tag = f"<{tag}>\n"

                                        if start_tag in content:
                                            # Remove the start tag
                                            content = content.replace(start_tag, "")
                                            ongoing_content = content

                                            reasoning_start_time = time.time()
                                            reasoning_content = ""

                                            current_tag = tag
                                            break

                                    if reasoning_start_time is not None:
                                        # While reasoning is open, ``content`` is
                                        # rebuilt from ``ongoing_content`` on
                                        # every chunk.
                                        reasoning_content += value

                                        end_tag = f"</{current_tag}>\n"
                                        if end_tag in reasoning_content:
                                            reasoning_duration = int(
                                                time.time() - reasoning_start_time
                                            )
                                            reasoning_content = _strip_reasoning_tags(
                                                reasoning_content, current_tag
                                            )

                                            if reasoning_content:
                                                reasoning_display_content = (
                                                    _quote_reasoning(reasoning_content)
                                                )

                                                # Format reasoning with <details> tag
                                                content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                                            else:
                                                # NOTE(review): this also drops
                                                # ``ongoing_content`` - confirm
                                                # that is intended.
                                                content = ""

                                            reasoning_start_time = None
                                        else:
                                            reasoning_display_content = (
                                                _quote_reasoning(reasoning_content)
                                            )

                                            # Show ongoing thought process
                                            content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

                                if ENABLE_REALTIME_CHAT_SAVE:
                                    # Save message in the database
                                    save_message({"content": content})
                                else:
                                    data = {
                                        "content": content,
                                    }

                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": data,
                            }
                        )
                    except Exception as e:
                        # Unparsable or unexpected chunks are skipped; the
                        # "[DONE]" sentinel is expected and not worth logging.
                        if "data: [DONE]" not in line:
                            log.debug(f"Skipping stream chunk: {e}")
                        continue

                title = Chats.get_chat_title_by_id(metadata["chat_id"])
                data = {"done": True, "content": content, "title": title}

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    save_message({"content": content})

                notify_inactive_user(title, content)

                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": data,
                    }
                )

                await background_tasks_handler()
            except asyncio.CancelledError:
                log.info("Task was cancelled!")
                await event_emitter({"type": "task-cancelled"})

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save whatever was accumulated before the cancellation
                    save_message({"content": content})

            if response.background is not None:
                await response.background()

        # background_tasks.add_task(post_response_handler, response, events)
        task_id, _ = create_task(post_response_handler(response, events))
        return {"status": True, "task_id": task_id}

    else:

        # Fallback to the original response
        async def stream_wrapper(original_generator, events):
            def wrap_item(item):
                return f"data: {item}\n\n"

            for event in events:
                yield wrap_item(json.dumps(event))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, events),
            headers=dict(response.headers),
            background=response.background,
        )
 |