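"""Chat middleware utilities for Open WebUI.

Request-side helpers (filters, tool calling, web search, image generation,
file retrieval) prepare the chat completion payload; response-side helpers
emit events, persist messages, and run background tasks such as title and
tag generation.
"""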
import time
import logging
import sys
import os
import base64

import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import html
import inspect
import re
import ast

from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor


from fastapi import Request
from fastapi import BackgroundTasks

from starlette.responses import Response, StreamingResponse


from open_webui.models.chats import Chats
from open_webui.models.users import Users
from open_webui.socket.main import (
    get_event_call,
    get_event_emitter,
    get_active_status_by_user_id,
)
from open_webui.routers.tasks import (
    generate_queries,
    generate_title,
    generate_image_prompt,
    generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm


from open_webui.utils.webhook import post_webhook


from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.retrieval.utils import get_sources_from_files


from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    get_task_model_id,
    rag_template,
    tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
    deep_update,
    get_message_list,
    add_or_update_system_message,
    add_or_update_user_message,
    get_last_user_message,
    get_last_assistant_message,
    prepend_to_first_user_message_content,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id


from open_webui.tasks import create_task

from open_webui.config import (
    CACHE_DIR,
    DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    DEFAULT_CODE_INTERPRETER_PROMPT,
)
from open_webui.env import (
    SRC_LOG_LEVELS,
    GLOBAL_LOG_LEVEL,
    BYPASS_MODEL_ACCESS_CONTROL,
    ENABLE_REALTIME_CHAT_SAVE,
)
from open_webui.constants import TASKS


logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


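# Filter ("inlet") pipeline: every enabled filter function (global filters plus the ones
# listed in the model's meta.filterIds) may rewrite the request body before it reaches the
# model, in ascending valve "priority" order. A filter exposing `file_handler` can opt to
# strip files from the request metadata.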
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
            filter_ids = list(set(filter_ids))

        enabled_filter_ids = [
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        ]

        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        filter_ids.sort(key=get_priority)
        return filter_ids

    filter_ids = get_filter_function_ids(model)
    for filter_id in filter_ids:
        filter = Functions.get_function_by_id(filter_id)
        if not filter:
            continue

        if filter_id in request.app.state.FUNCTIONS:
            function_module = request.app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            request.app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if hasattr(function_module, "inlet"):
            try:
                inlet = function_module.inlet

                # Create a dictionary of parameters to be passed to the function
                params = {"body": body} | {
                    k: v
                    for k, v in {
                        **extra_params,
                        "__model__": model,
                        "__id__": filter_id,
                    }.items()
                    if k in inspect.signature(inlet).parameters
                }

                if "__user__" in params and hasattr(function_module, "UserValves"):
                    try:
                        params["__user__"]["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, params["__user__"]["id"]
                            )
                        )
                    except Exception as e:
                        print(e)

                if inspect.iscoroutinefunction(inlet):
                    body = await inlet(**params)
                else:
                    body = inlet(**params)

            except Exception as e:
                print(f"Error: {e}")
                raise e

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}


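# Prompt-based (non-native) tool calling: a task model is prompted with the tool specs to
# emit a JSON tool call, the matching callable from `tools` is executed, and string outputs
# are collected as sources for citation. Tools marked as file handlers cause attached files
# to be dropped from the request metadata.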
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    async def get_content_from_response(response) -> Optional[str]:
        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        user_message = get_last_user_message(messages)
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    skip_files = False
    sources = []

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            async def tool_call_handler(tool_call):
                nonlocal skip_files

                log.debug(f"{tool_call=}")

                tool_function_name = tool_call.get("name", None)
                if tool_function_name not in tools:
                    return body, {}

                tool_function_params = tool_call.get("parameters", {})

                try:
                    required_params = (
                        tools[tool_function_name]
                        .get("spec", {})
                        .get("parameters", {})
                        .get("required", [])
                    )
                    tool_function = tools[tool_function_name]["callable"]
                    tool_function_params = {
                        k: v
                        for k, v in tool_function_params.items()
                        if k in required_params
                    }
                    tool_output = await tool_function(**tool_function_params)

                except Exception as e:
                    tool_output = str(e)

                if isinstance(tool_output, str):
                    if tools[tool_function_name]["citation"]:
                        sources.append(
                            {
                                "source": {
                                    "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                },
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )
                    else:
                        sources.append(
                            {
                                "source": {},
                                "document": [tool_output],
                                "metadata": [
                                    {
                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
                                    }
                                ],
                            }
                        )

                    if tools[tool_function_name]["file_handler"]:
                        skip_files = True

            # check if "tool_calls" in result
            if result.get("tool_calls"):
                for tool_call in result.get("tool_calls"):
                    await tool_call_handler(tool_call)
            else:
                await tool_call_handler(result)

        except Exception as e:
            log.exception(f"Error: {e}")
            content = None
    except Exception as e:
        log.exception(f"Error: {e}")
        content = None

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}


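# Web search feature: generate search queries from the conversation, run the first query
# through process_web_search in a worker thread, and attach the resulting collection to
# form_data["files"] so the retrieval step can pick it up.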
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Generating search query",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1

            if bracket_start == -1 or bracket_end == 0:
                raise Exception("No JSON object found in the response")

            response = response[bracket_start:bracket_end]
            queries = json.loads(response)
            queries = queries.get("queries", [])
        except Exception as e:
            queries = [response]

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        return form_data

    searchQuery = queries[0]

    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": 'Searching "{{searchQuery}}"',
                "query": searchQuery,
                "done": False,
            },
        }
    )

    try:

        # Offload process_web_search to a separate thread
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            results = await loop.run_in_executor(
                executor,
                lambda: process_web_search(
                    request,
                    SearchForm(
                        **{
                            "query": searchQuery,
                        }
                    ),
                    user,
                ),
            )

        if results:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "Searched {{count}} sites",
                        "query": searchQuery,
                        "urls": results["filenames"],
                        "done": True,
                    },
                }
            )

            files = form_data.get("files", [])
            files.append(
                {
                    "collection_name": results["collection_name"],
                    "name": searchQuery,
                    "type": "web_search_results",
                    "urls": results["filenames"],
                }
            )
            form_data["files"] = files
        else:
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": "No search results found",
                        "query": searchQuery,
                        "done": True,
                        "error": True,
                    },
                }
            )
    except Exception as e:
        log.exception(e)
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": 'Error searching "{{searchQuery}}"',
                    "query": searchQuery,
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data


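# Image generation feature: optionally rewrite the conversation into an image prompt (when
# prompt generation is enabled), call the images router, emit the results as message events,
# and leave a system-message hint describing the outcome for the model.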
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message
    negative_prompt = ""

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                response = response[bracket_start:bracket_end]
                response = json.loads(response)
                prompt = response.get("prompt", [])
            except Exception as e:
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        for image in images:
            await __event_emitter__(
                {
                    "type": "message",
                    "data": {"content": f"![Generated Image]({image['url']})\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": "An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data


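# File / knowledge retrieval: generate retrieval queries from the conversation for any files
# attached in the request metadata, then gather matching chunks with get_sources_from_files
# in a worker thread.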
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = queries_response.find("{")
                bracket_end = queries_response.rfind("}") + 1

                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                queries_response = queries_response[bracket_start:bracket_end]
                queries_response = json.loads(queries_response)
            except Exception as e:
                queries_response = {"queries": [queries_response]}

            queries = queries_response.get("queries", [])
        except Exception as e:
            queries = []

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_files to a separate thread
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        files=files,
                        queries=queries,
                        embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
                            query, user=user
                        ),
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                    ),
                )

        except Exception as e:
            log.exception(e)

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}


def apply_params_to_form_data(form_data, model):
    params = form_data.pop("params", {})
    if model.get("ollama"):
        form_data["options"] = params

        if "format" in params:
            form_data["format"] = params["format"]

        if "keep_alive" in params:
            form_data["keep_alive"] = params["keep_alive"]
    else:
        if "seed" in params:
            form_data["seed"] = params["seed"]

        if "stop" in params:
            form_data["stop"] = params["stop"]

        if "temperature" in params:
            form_data["temperature"] = params["temperature"]

        if "max_tokens" in params:
            form_data["max_tokens"] = params["max_tokens"]

        if "top_p" in params:
            form_data["top_p"] = params["top_p"]

        if "frequency_penalty" in params:
            form_data["frequency_penalty"] = params["frequency_penalty"]

        if "reasoning_effort" in params:
            form_data["reasoning_effort"] = params["reasoning_effort"]

    return form_data


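# Request-side orchestration: applies model params, merges model knowledge into the file
# list, runs the feature handlers (web search, image generation, code interpreter), filter
# inlets, tool calling, and retrieval, then injects the collected sources into the messages
# via the RAG template before the payload is forwarded to the model.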
async def process_chat_payload(request, form_data, metadata, user, model):

    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }

    # Initialize events to store additional events to be sent to the client
    # Initialize contexts and citations
    models = request.app.state.MODELS
    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    variables = form_data.pop("variables", None)

    features = form_data.pop("features", None)
    if features:
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        if "code_interpreter" in features and features["code_interpreter"]:
            form_data["messages"] = add_or_update_user_message(
                DEFAULT_CODE_INTERPRETER_PROMPT, form_data["messages"]
            )

    try:
        form_data, flags = await chat_completion_filter_functions_handler(
            request, form_data, model, extra_params
        )
    except Exception as e:
        raise Exception(f"Error: {e}")

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)
    # Remove duplicate files
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")

    if tool_ids:
        # If the tool_ids field is present, fetch the corresponding tools
        tools = get_tools(
            request,
            tool_ids,
            user,
            {
                **extra_params,
                "__model__": models[task_model_id],
                "__messages__": form_data["messages"],
                "__files__": metadata.get("files", []),
            },
        )
        log.info(f"{tools=}")

        if metadata.get("function_calling") == "native":
            # If function calling is native, pass the tool specs to the model directly
            metadata["tools"] = tools
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools.values()
            ]
        else:
            # Otherwise, use the prompt-based tools function calling handler
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, user, models, tools
                )
                sources.extend(flags.get("sources", []))

            except Exception as e:
                log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # If context is not empty, insert it into the messages
    if len(sources) > 0:
        context_string = ""
        for source_idx, source in enumerate(sources):
            source_id = source.get("source", {}).get("name", "")

            if "document" in source:
                for doc_idx, doc_context in enumerate(source["document"]):
                    doc_metadata = source.get("metadata")
                    doc_source_id = None

                    if doc_metadata:
                        doc_source_id = doc_metadata[doc_idx].get("source", source_id)

                    if source_id:
                        context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                    else:
                        # If there is no source_id, then do not include the source_id tag
                        context_string += f"<source><source_context>{doc_context}</source_context></source>\n"

        context_string = context_string.strip()
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if (
            request.app.state.config.RELEVANCE_THRESHOLD == 0
            and context_string.strip() == ""
        ):
            log.debug(
                "With a 0 relevance threshold for RAG, the context cannot be empty"
            )

        # Workaround for Ollama 2.0+ system prompt issue
        # TODO: replace with add_or_update_system_message
        if model["owned_by"] == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )

    # If there are citations, add them to the data_items
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, metadata, events


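# Response-side orchestration: emits completion events, persists messages, and triggers
# background tasks (title and tag generation). For streaming responses, it parses the SSE
# stream into structured content blocks (text, reasoning, tool calls, code interpreter)
# inside a background task.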
async def process_chat_response(
    request, response, form_data, user, events, metadata, tasks
):
    async def background_tasks_handler():
        message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
        message = message_map.get(metadata["message_id"]) if message_map else None

        if message:
            messages = get_message_list(message_map, message.get("id"))

            if tasks and messages:
                if TASKS.TITLE_GENERATION in tasks:
                    if tasks[TASKS.TITLE_GENERATION]:
                        res = await generate_title(
                            request,
                            {
                                "model": message["model"],
                                "messages": messages,
                                "chat_id": metadata["chat_id"],
                            },
                            user,
                        )

                        if res and isinstance(res, dict):
                            if len(res.get("choices", [])) == 1:
                                title_string = (
                                    res.get("choices", [])[0]
                                    .get("message", {})
                                    .get("content", message.get("content", "New Chat"))
                                )
                            else:
                                title_string = ""

                            title_string = title_string[
                                title_string.find("{") : title_string.rfind("}") + 1
                            ]

                            try:
                                title = json.loads(title_string).get(
                                    "title", "New Chat"
                                )
                            except Exception as e:
                                title = ""

                            if not title:
                                title = messages[0].get("content", "New Chat")

                            Chats.update_chat_title_by_id(metadata["chat_id"], title)

                            await event_emitter(
                                {
                                    "type": "chat:title",
                                    "data": title,
                                }
                            )
                    elif len(messages) == 2:
                        title = messages[0].get("content", "New Chat")

                        Chats.update_chat_title_by_id(metadata["chat_id"], title)

                        await event_emitter(
                            {
                                "type": "chat:title",
                                "data": message.get("content", "New Chat"),
                            }
                        )

                if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
                    res = await generate_chat_tags(
                        request,
                        {
                            "model": message["model"],
                            "messages": messages,
                            "chat_id": metadata["chat_id"],
                        },
                        user,
                    )

                    if res and isinstance(res, dict):
                        if len(res.get("choices", [])) == 1:
                            tags_string = (
                                res.get("choices", [])[0]
                                .get("message", {})
                                .get("content", "")
                            )
                        else:
                            tags_string = ""

                        tags_string = tags_string[
                            tags_string.find("{") : tags_string.rfind("}") + 1
                        ]

                        try:
                            tags = json.loads(tags_string).get("tags", [])
                            Chats.update_chat_tags_by_id(
                                metadata["chat_id"], tags, user
                            )

                            await event_emitter(
                                {
                                    "type": "chat:tags",
                                    "data": tags,
                                }
                            )
                        except Exception as e:
                            pass

    event_emitter = None
    event_caller = None
    if (
        "session_id" in metadata
        and metadata["session_id"]
        and "chat_id" in metadata
        and metadata["chat_id"]
        and "message_id" in metadata
        and metadata["message_id"]
    ):
        event_emitter = get_event_emitter(metadata)
        event_caller = get_event_call(metadata)

    # Non-streaming response
    if not isinstance(response, StreamingResponse):
        if event_emitter:
            if "selected_model_id" in response:
                Chats.upsert_message_to_chat_by_id_and_message_id(
                    metadata["chat_id"],
                    metadata["message_id"],
                    {
                        "selectedModelId": response["selected_model_id"],
                    },
                )

            if response.get("choices", [])[0].get("message", {}).get("content"):
                content = response["choices"][0]["message"]["content"]

                if content:

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": response,
                        }
                    )

                    title = Chats.get_chat_title_by_id(metadata["chat_id"])

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "done": True,
                                "content": content,
                                "title": title,
                            },
                        }
                    )

                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": content,
                        },
                    )

                    # Send a webhook notification if the user is not active
                    if get_active_status_by_user_id(user.id) is None:
                        webhook_url = Users.get_user_webhook_url_by_id(user.id)
                        if webhook_url:
                            post_webhook(
                                webhook_url,
                                f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                                {
                                    "action": "chat",
                                    "message": content,
                                    "title": title,
                                    "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                                },
                            )

                    await background_tasks_handler()

            return response
        else:
            return response

    # Non-standard response
    if not any(
        content_type in response.headers["Content-Type"]
        for content_type in ["text/event-stream", "application/x-ndjson"]
    ):
        return response

    # Streaming response
    if event_emitter and event_caller:
        task_id = str(uuid4())  # Create a unique task ID.
        model_id = form_data.get("model", "")

        Chats.upsert_message_to_chat_by_id_and_message_id(
            metadata["chat_id"],
            metadata["message_id"],
            {
                "model": model_id,
            },
        )

        # Handle as a background task
        async def post_response_handler(response, events):
            def serialize_content_blocks(content_blocks, raw=False):
                content = ""

                for block in content_blocks:
                    if block["type"] == "text":
                        content = f"{content}{block['content'].strip()}\n"
                    elif block["type"] == "tool_calls":
                        attributes = block.get("attributes", {})

                        block_content = block.get("content", [])
                        results = block.get("results", [])

                        if results:

                            result_display_content = ""

                            for result in results:
                                tool_call_id = result.get("tool_call_id", "")
                                tool_name = ""

                                for tool_call in block_content:
                                    if tool_call.get("id", "") == tool_call_id:
                                        tool_name = tool_call.get("function", {}).get(
                                            "name", ""
                                        )
                                        break

                                result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}"

                            if not raw:
                                content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n'
                        else:
                            tool_calls_display_content = ""

                            for tool_call in block_content:
                                tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}"

                            if not raw:
                                content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n'

                    elif block["type"] == "reasoning":
                        reasoning_display_content = "\n".join(
                            (f"> {line}" if not line.startswith(">") else line)
                            for line in block["content"].splitlines()
                        )

                        reasoning_duration = block.get("duration", None)

                        if reasoning_duration:
                            if raw:
                                content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
                            else:
                                content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                        else:
                            if raw:
                                content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
                            else:
                                content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

                    elif block["type"] == "code_interpreter":
                        attributes = block.get("attributes", {})
                        output = block.get("output", None)
                        lang = attributes.get("lang", "")

                        if output:
                            output = html.escape(json.dumps(output))

                            if raw:
                                content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                            else:
                                content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
                        else:
                            if raw:
                                content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                            else:
                                content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'

                    else:
                        block_content = str(block["content"]).strip()
                        content = f"{content}{block['type']}: {block_content}\n"

                return content

            def tag_content_handler(content_type, tags, content, content_blocks):
                end_flag = False

                def extract_attributes(tag_content):
                    """Extract attributes from a tag if they exist."""
                    attributes = {}
                    # Match attributes in the format: key="value" (ignores single quotes for simplicity)
                    matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
                    for key, value in matches:
                        attributes[key] = value
                    return attributes

                if content_blocks[-1]["type"] == "text":
                    for tag in tags:
                        # Match start tag e.g., <tag> or <tag attr="value">
                        start_tag_pattern = rf"<{tag}(.*?)>"
                        match = re.search(start_tag_pattern, content)
                        if match:
                            # Extract attributes in the tag (if present)
                            attributes = extract_attributes(match.group(1))
                            # Remove the start tag from the currently handled text block
                            content_blocks[-1]["content"] = content_blocks[-1][
                                "content"
                            ].replace(match.group(0), "")
                            if not content_blocks[-1]["content"]:
                                content_blocks.pop()
                            # Append the new block
                            content_blocks.append(
                                {
                                    "type": content_type,
                                    "tag": tag,
                                    "attributes": attributes,
                                    "content": "",
                                    "started_at": time.time(),
                                }
                            )
                            break
                elif content_blocks[-1]["type"] == content_type:
                    tag = content_blocks[-1]["tag"]
                    # Match end tag e.g., </tag>
                    end_tag_pattern = rf"</{tag}>"
                    if re.search(end_tag_pattern, content):
                        block_content = content_blocks[-1]["content"]
                        # Strip start and end tags from the content
                        start_tag_pattern = rf"<{tag}(.*?)>"
                        block_content = re.sub(
                            start_tag_pattern, "", block_content
                        ).strip()
                        block_content = re.sub(
                            end_tag_pattern, "", block_content
                        ).strip()
                        if block_content:
                            end_flag = True
                            content_blocks[-1]["content"] = block_content
                            content_blocks[-1]["ended_at"] = time.time()
                            content_blocks[-1]["duration"] = int(
                                content_blocks[-1]["ended_at"]
                                - content_blocks[-1]["started_at"]
                            )
                            # Reset the content_blocks by appending a new text block
                            content_blocks.append(
                                {
                                    "type": "text",
                                    "content": "",
                                }
                            )
                            # Clean processed content
                            content = re.sub(
                                rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
                                "",
                                content,
                                flags=re.DOTALL,
                            )
                        else:
                            # Remove the block if content is empty
                            content_blocks.pop()
                return content, content_blocks, end_flag

|             message = Chats.get_message_by_id_and_message_id(
 | |
|                 metadata["chat_id"], metadata["message_id"]
 | |
|             )
 | |
| 
 | |
|             tool_calls = []
 | |
|             content = message.get("content", "") if message else ""
 | |
|             content_blocks = [
 | |
|                 {
 | |
|                     "type": "text",
 | |
|                     "content": content,
 | |
|                 }
 | |
|             ]
 | |
| 
 | |
|             # We might want to disable this by default
 | |
|             DETECT_REASONING = True
 | |
|             DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
 | |
|                 "code_interpreter", False
 | |
|             )
 | |
| 
 | |
|             reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
 | |
|             code_interpreter_tags = ["code_interpreter"]
 | |
| 
 | |
            try:
                for event in events:
                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": event,
                        }
                    )

                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            **event,
                        },
                    )

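                # Consume the upstream SSE stream: accumulate text deltas and
                # tool-call fragments, and split tagged reasoning / code
                # interpreter sections into their own content blocks.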
                async def stream_body_handler(response):
                    nonlocal content
                    nonlocal content_blocks

                    response_tool_calls = []

                    async for line in response.body_iterator:
                        line = line.decode("utf-8") if isinstance(line, bytes) else line
                        data = line

                        # Skip empty lines
                        if not data.strip():
                            continue

                        # "data:" is the prefix for each event
                        if not data.startswith("data:"):
                            continue

                        # Remove the prefix
                        data = data[len("data:") :].strip()

                        try:
                            data = json.loads(data)

                            if "selected_model_id" in data:
                                model_id = data["selected_model_id"]
                                Chats.upsert_message_to_chat_by_id_and_message_id(
                                    metadata["chat_id"],
                                    metadata["message_id"],
                                    {
                                        "selectedModelId": model_id,
                                    },
                                )
                            else:
                                choices = data.get("choices", [])
                                if not choices:
                                    continue

                                delta = choices[0].get("delta", {})
                                delta_tool_calls = delta.get("tool_calls", None)

                                if delta_tool_calls:
                                    for delta_tool_call in delta_tool_calls:
                                        tool_call_index = delta_tool_call.get("index")

                                        if tool_call_index is not None:
                                            if (
                                                len(response_tool_calls)
                                                <= tool_call_index
                                            ):
                                                response_tool_calls.append(
                                                    delta_tool_call
                                                )
                                            else:
                                                delta_name = delta_tool_call.get(
                                                    "function", {}
                                                ).get("name")
                                                delta_arguments = delta_tool_call.get(
                                                    "function", {}
                                                ).get("arguments")

                                                if delta_name:
                                                    response_tool_calls[
                                                        tool_call_index
                                                    ]["function"]["name"] += delta_name

                                                if delta_arguments:
                                                    response_tool_calls[
                                                        tool_call_index
                                                    ]["function"][
                                                        "arguments"
                                                    ] += delta_arguments

                                value = delta.get("content")

                                if value:
                                    content = f"{content}{value}"
                                    content_blocks[-1]["content"] = (
                                        content_blocks[-1]["content"] + value
                                    )

                                    if DETECT_REASONING:
                                        content, content_blocks, _ = (
                                            tag_content_handler(
                                                "reasoning",
                                                reasoning_tags,
                                                content,
                                                content_blocks,
                                            )
                                        )

                                    if DETECT_CODE_INTERPRETER:
                                        content, content_blocks, end = (
                                            tag_content_handler(
                                                "code_interpreter",
                                                code_interpreter_tags,
                                                content,
                                                content_blocks,
                                            )
                                        )

                                        if end:
                                            break

                                    if ENABLE_REALTIME_CHAT_SAVE:
                                        # Save message in the database
                                        Chats.upsert_message_to_chat_by_id_and_message_id(
                                            metadata["chat_id"],
                                            metadata["message_id"],
                                            {
                                                "content": serialize_content_blocks(
                                                    content_blocks
                                                ),
                                            },
                                        )
                                    else:
                                        data = {
                                            "content": serialize_content_blocks(
                                                content_blocks
                                            ),
                                        }

                            await event_emitter(
                                {
                                    "type": "chat:completion",
                                    "data": data,
                                }
                            )
                        except Exception as e:
                            done = "data: [DONE]" in line
                            if done:
                                pass
                            else:
                                log.debug(f"Error: {e}")
                                continue

                    # Clean up the last text block
                    if content_blocks[-1]["type"] == "text":
                        content_blocks[-1]["content"] = content_blocks[-1][
                            "content"
                        ].strip()

                        if not content_blocks[-1]["content"]:
                            content_blocks.pop()

                    if response_tool_calls:
                        tool_calls.append(response_tool_calls)

                    if response.background:
                        await response.background()

                await stream_body_handler(response)

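                # If the model emitted tool calls, execute them and feed the
                # results back for follow-up completions, bounded by a retry cap.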
                MAX_TOOL_CALL_RETRIES = 5
                tool_call_retries = 0

                while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
                    tool_call_retries += 1

                    response_tool_calls = tool_calls.pop(0)

                    content_blocks.append(
                        {
                            "type": "tool_calls",
                            "content": response_tool_calls,
                        }
                    )

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "content": serialize_content_blocks(content_blocks),
                            },
                        }
                    )

                    tools = metadata.get("tools", {})

                    results = []
                    for tool_call in response_tool_calls:
                        log.debug(f"tool_call: {tool_call}")
                        tool_call_id = tool_call.get("id", "")
                        tool_name = tool_call.get("function", {}).get("name", "")

                        tool_function_params = {}
                        try:
                            # json.loads cannot be used because some models do not produce valid JSON
                            tool_function_params = ast.literal_eval(
                                tool_call.get("function", {}).get("arguments", "{}")
                            )
                        except Exception as e:
                            log.debug(e)

                        tool_result = None

                        if tool_name in tools:
                            tool = tools[tool_name]
                            spec = tool.get("spec", {})

                            try:
                                required_params = spec.get("parameters", {}).get(
                                    "required", []
                                )
                                tool_function = tool["callable"]
                                tool_function_params = {
                                    k: v
                                    for k, v in tool_function_params.items()
                                    if k in required_params
                                }
                                tool_result = await tool_function(
                                    **tool_function_params
                                )
                            except Exception as e:
                                tool_result = str(e)

                        results.append(
                            {
                                "tool_call_id": tool_call_id,
                                "content": tool_result,
                            }
                        )

                    content_blocks[-1]["results"] = results

                    content_blocks.append(
                        {
                            "type": "text",
                            "content": "",
                        }
                    )

                    await event_emitter(
                        {
                            "type": "chat:completion",
                            "data": {
                                "content": serialize_content_blocks(content_blocks),
                            },
                        }
                    )

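                    # Ask the model to continue: pass the assistant message with
                    # its tool_calls plus one "tool" message per result.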
                    try:
                        res = await generate_chat_completion(
                            request,
                            {
                                "model": model_id,
                                "stream": True,
                                "messages": [
                                    *form_data["messages"],
                                    {
                                        "role": "assistant",
                                        "content": serialize_content_blocks(
                                            content_blocks, raw=True
                                        ),
                                        "tool_calls": response_tool_calls,
                                    },
                                    *[
                                        {
                                            "role": "tool",
                                            "tool_call_id": result["tool_call_id"],
                                            "content": result["content"],
                                        }
                                        for result in results
                                    ],
                                ],
                            },
                            user,
                        )

                        if isinstance(res, StreamingResponse):
                            await stream_body_handler(res)
                        else:
                            break
                    except Exception as e:
                        log.debug(e)
                        break

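                # Execute any trailing code-interpreter block, attach its output,
                # and let the model continue with the result, bounded by retries.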
                if DETECT_CODE_INTERPRETER:
                    MAX_RETRIES = 5
                    retries = 0

                    while (
                        content_blocks[-1]["type"] == "code_interpreter"
                        and retries < MAX_RETRIES
                    ):
                        retries += 1
                        log.debug(f"Attempt count: {retries}")

                        output = ""
                        try:
                            if content_blocks[-1]["attributes"].get("type") == "code":
                                output = await event_caller(
                                    {
                                        "type": "execute:python",
                                        "data": {
                                            "id": str(uuid4()),
                                            "code": content_blocks[-1]["content"],
                                        },
                                    }
                                )

                                if isinstance(output, dict):
                                    stdout = output.get("stdout", "")

                                    if stdout:
                                        stdoutLines = stdout.split("\n")
                                        for idx, line in enumerate(stdoutLines):
                                            if "data:image/png;base64" in line:
                                                id = str(uuid4())

                                                # ensure the path exists
                                                os.makedirs(
                                                    os.path.join(CACHE_DIR, "images"),
                                                    exist_ok=True,
                                                )

                                                image_path = os.path.join(
                                                    CACHE_DIR,
                                                    f"images/{id}.png",
                                                )

                                                with open(image_path, "wb") as f:
                                                    f.write(
                                                        base64.b64decode(
                                                            line.split(",")[1]
                                                        )
                                                    )

                                                # Replace the raw base64 line with a
                                                # markdown reference to the saved image
                                                stdoutLines[idx] = (
                                                    f"![Output Image {idx}](/cache/images/{id}.png)"
                                                )

                                        output["stdout"] = "\n".join(stdoutLines)
                        except Exception as e:
                            output = str(e)

                        content_blocks[-1]["output"] = output

                        content_blocks.append(
                            {
                                "type": "text",
                                "content": "",
                            }
                        )

                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": {
                                    "content": serialize_content_blocks(content_blocks),
                                },
                            }
                        )

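                        # Feed the execution output back to the model for a
                        # follow-up completion.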
                        try:
                            res = await generate_chat_completion(
                                request,
                                {
                                    "model": model_id,
                                    "stream": True,
                                    "messages": [
                                        *form_data["messages"],
                                        {
                                            "role": "assistant",
                                            "content": serialize_content_blocks(
                                                content_blocks, raw=True
                                            ),
                                        },
                                    ],
                                },
                                user,
                            )

                            if isinstance(res, StreamingResponse):
                                await stream_body_handler(res)
                            else:
                                break
                        except Exception as e:
                            log.debug(e)
                            break

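                # All streaming is done: build the final payload, persist the
                # message, and notify the user if they are no longer active.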
                title = Chats.get_chat_title_by_id(metadata["chat_id"])
                data = {
                    "done": True,
                    "content": serialize_content_blocks(content_blocks),
                    "title": title,
                }

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": serialize_content_blocks(content_blocks),
                        },
                    )

                # Send a webhook notification if the user is not active
                if get_active_status_by_user_id(user.id) is None:
                    webhook_url = Users.get_user_webhook_url_by_id(user.id)
                    if webhook_url:
                        post_webhook(
                            webhook_url,
                            f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                            {
                                "action": "chat",
                                "message": content,
                                "title": title,
                                "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                            },
                        )

                await event_emitter(
                    {
                        "type": "chat:completion",
                        "data": data,
                    }
                )

                await background_tasks_handler()
            except asyncio.CancelledError:
                log.warning("Task was cancelled!")
                await event_emitter({"type": "task-cancelled"})

                if not ENABLE_REALTIME_CHAT_SAVE:
                    # Save message in the database
                    Chats.upsert_message_to_chat_by_id_and_message_id(
                        metadata["chat_id"],
                        metadata["message_id"],
                        {
                            "content": serialize_content_blocks(content_blocks),
                        },
                    )

            if response.background is not None:
                await response.background()

        # background_tasks.add_task(post_response_handler, response, events)
        task_id, _ = create_task(post_response_handler(response, events))
        return {"status": True, "task_id": task_id}

    else:

        # Fallback to the original response
        async def stream_wrapper(original_generator, events):
            def wrap_item(item):
                return f"data: {item}\n\n"

            for event in events:
                yield wrap_item(json.dumps(event))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, events),
            headers=dict(response.headers),
            background=response.background,
        )