| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | from fastapi import APIRouter, Depends, HTTPException, Response, status, Request | 
					
						
							| 
									
										
										
										
											2024-12-12 11:52:46 +08:00
										 |  |  | from fastapi.responses import JSONResponse, RedirectResponse | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | from pydantic import BaseModel | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2025-01-30 06:59:23 +08:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  | from open_webui.utils.chat import generate_chat_completion | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | from open_webui.utils.task import ( | 
					
						
							|  |  |  |     title_generation_template, | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |     follow_up_generation_template, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     query_generation_template, | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |     image_prompt_generation_template, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     autocomplete_generation_template, | 
					
						
							|  |  |  |     tags_generation_template, | 
					
						
							|  |  |  |     emoji_generation_template, | 
					
						
							|  |  |  |     moa_response_generation_template, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | from open_webui.utils.auth import get_admin_user, get_verified_user | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.constants import TASKS | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 11:52:46 +08:00
										 |  |  | from open_webui.routers.pipelines import process_pipeline_inlet_filter | 
					
						
							| 
									
										
										
										
											2025-05-17 03:33:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 11:52:46 +08:00
										 |  |  | from open_webui.utils.task import get_task_model_id | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.config import ( | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |     DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |     DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |     DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |     DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  |     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |     DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |     DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | from open_webui.env import SRC_LOG_LEVELS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level logger; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router on which this module's task endpoints (config, title, follow-up,
# tags, ...) are registered via the @router decorators below.
router = APIRouter()


##################################
#
# Task Endpoints
#
##################################
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.get("/config") | 
					
						
							|  |  |  | async def get_task_config(request: Request, user=Depends(get_verified_user)): | 
					
						
							|  |  |  |     return { | 
					
						
							|  |  |  |         "TASK_MODEL": request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |         "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, | 
					
						
							|  |  |  |         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, | 
					
						
							|  |  |  |         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |         "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |         "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, | 
					
						
							| 
									
										
										
										
											2025-02-13 23:28:39 +08:00
										 |  |  |         "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, | 
					
						
							|  |  |  |         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, | 
					
						
							|  |  |  |         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class TaskConfigForm(BaseModel):
    """Request body for ``POST /config/update``.

    ``update_task_config`` copies every field verbatim onto
    ``request.app.state.config`` under the same attribute name. No field
    declares a default value.
    """

    # Task-model selection; both are passed to get_task_model_id() by the
    # generation endpoints to choose which model runs a background task.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # When False, /title/completions returns a 200 "disabled" response.
    ENABLE_TITLE_GENERATION: bool
    # An empty string makes /title/completions fall back to
    # DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE.
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # NOTE(review): presumably a character limit on autocomplete input —
    # confirm against the autocomplete endpoint.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # An empty string makes /tags/completions fall back to
    # DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    # An empty string makes /follow_up/completions fall back to
    # DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE.
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    # When False, /follow_up/completions returns a 200 "disabled" response.
    ENABLE_FOLLOW_UP_GENERATION: bool
    # When False, /tags/completions returns a 200 "disabled" response.
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/config/update") | 
					
						
							|  |  |  | async def update_task_config( | 
					
						
							|  |  |  |     request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     request.app.state.config.TASK_MODEL = form_data.TASK_MODEL | 
					
						
							|  |  |  |     request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL | 
					
						
							| 
									
										
										
										
											2025-02-13 23:28:39 +08:00
										 |  |  |     request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.TITLE_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |     request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = ( | 
					
						
							|  |  |  |         form_data.ENABLE_FOLLOW_UP_GENERATION | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-01 15:50:58 +08:00
										 |  |  |     request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( | 
					
						
							|  |  |  |         form_data.ENABLE_AUTOCOMPLETE_GENERATION | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( | 
					
						
							|  |  |  |         form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.TAGS_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION | 
					
						
							|  |  |  |     request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ( | 
					
						
							|  |  |  |         form_data.ENABLE_SEARCH_QUERY_GENERATION | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ( | 
					
						
							|  |  |  |         form_data.ENABLE_RETRIEVAL_QUERY_GENERATION | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.QUERY_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( | 
					
						
							|  |  |  |         form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return { | 
					
						
							|  |  |  |         "TASK_MODEL": request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							| 
									
										
										
										
											2025-02-14 14:54:45 +08:00
										 |  |  |         "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |         "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, | 
					
						
							|  |  |  |         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, | 
					
						
							|  |  |  |         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |         "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, | 
					
						
							|  |  |  |         "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, | 
					
						
							|  |  |  |         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, | 
					
						
							|  |  |  |         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/title/completions") | 
					
						
							|  |  |  | async def generate_title( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2025-02-13 23:28:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if not request.app.state.config.ENABLE_TITLE_GENERATION: | 
					
						
							|  |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_200_OK, | 
					
						
							|  |  |  |             content={"detail": "Title generation is disabled"}, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-02-14 14:54:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating chat title using model {task_model_id} for user {user.email} " | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": | 
					
						
							|  |  |  |         template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     content = title_generation_template( | 
					
						
							|  |  |  |         template, | 
					
						
							| 
									
										
										
										
											2025-05-08 14:51:50 +08:00
										 |  |  |         form_data["messages"], | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         { | 
					
						
							|  |  |  |             "name": user.name, | 
					
						
							|  |  |  |             "location": user.info.get("location") if user.info else None, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-19 03:44:51 +08:00
										 |  |  |     max_tokens = ( | 
					
						
							|  |  |  |         models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         **( | 
					
						
							| 
									
										
										
										
											2025-05-19 03:44:51 +08:00
										 |  |  |             {"max_tokens": max_tokens} | 
					
						
							| 
									
										
										
										
											2025-02-16 08:41:41 +08:00
										 |  |  |             if models[task_model_id].get("owned_by") == "ollama" | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             else { | 
					
						
							| 
									
										
										
										
											2025-05-19 03:44:51 +08:00
										 |  |  |                 "max_completion_tokens": max_tokens, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         ), | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             "task": str(TASKS.TITLE_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-03 22:07:29 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.error("Exception occurred", exc_info=True) | 
					
						
							|  |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             content={"detail": "An internal error has occurred."}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/follow_up/completions") | 
					
						
							|  |  |  | async def generate_follow_ups( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION: | 
					
						
							|  |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_200_OK, | 
					
						
							|  |  |  |             content={"detail": "Follow-up generation is disabled"}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							|  |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating chat title using model {task_model_id} for user {user.email} " | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "": | 
					
						
							|  |  |  |         template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     content = follow_up_generation_template( | 
					
						
							|  |  |  |         template, | 
					
						
							|  |  |  |         form_data["messages"], | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "name": user.name, | 
					
						
							|  |  |  |             "location": user.info.get("location") if user.info else None, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							|  |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							|  |  |  |             "task": str(TASKS.FOLLOW_UP_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-17 03:59:50 +08:00
										 |  |  |         log.error("Exception occurred", exc_info=True) | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							| 
									
										
										
										
											2024-12-17 03:59:50 +08:00
										 |  |  |             content={"detail": "An internal error has occurred."}, | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/tags/completions") | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | async def generate_chat_tags( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if not request.app.state.config.ENABLE_TAGS_GENERATION: | 
					
						
							|  |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_200_OK, | 
					
						
							|  |  |  |             content={"detail": "Tags generation is disabled"}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating chat tags using model {task_model_id} for user {user.email} " | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": | 
					
						
							|  |  |  |         template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     content = tags_generation_template( | 
					
						
							|  |  |  |         template, form_data["messages"], {"name": user.name} | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             "task": str(TASKS.TAGS_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-17 04:27:11 +08:00
										 |  |  |         log.error(f"Error generating chat completion: {e}") | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							| 
									
										
										
										
											2024-12-17 04:27:11 +08:00
										 |  |  |             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | 
					
						
							|  |  |  |             content={"detail": "An internal error has occurred."}, | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  | @router.post("/image_prompt/completions") | 
					
						
							|  |  |  | async def generate_image_prompt( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating image prompt using model {task_model_id} for user {user.email} " | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "": | 
					
						
							|  |  |  |         template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     content = image_prompt_generation_template( | 
					
						
							|  |  |  |         template, | 
					
						
							|  |  |  |         form_data["messages"], | 
					
						
							|  |  |  |         user={ | 
					
						
							|  |  |  |             "name": user.name, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |             "task": str(TASKS.IMAGE_PROMPT_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-16 16:06:37 +08:00
										 |  |  |     try: | 
					
						
							|  |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.error("Exception occurred", exc_info=True) | 
					
						
							|  |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             content={"detail": "An internal error has occurred."}, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | @router.post("/queries/completions") | 
					
						
							|  |  |  | async def generate_queries( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     type = form_data.get("type") | 
					
						
							|  |  |  |     if type == "web_search": | 
					
						
							|  |  |  |         if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |                 detail=f"Search query generation is disabled", | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |     elif type == "retrieval": | 
					
						
							|  |  |  |         if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |                 detail=f"Query generation is disabled", | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating {type} queries using model {task_model_id} for user {user.email}" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": | 
					
						
							|  |  |  |         template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     content = query_generation_template( | 
					
						
							|  |  |  |         template, form_data["messages"], {"name": user.name} | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             "task": str(TASKS.QUERY_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             content={"detail": str(e)}, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/auto/completions") | 
					
						
							|  |  |  | async def generate_autocompletion( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             detail=f"Autocompletion generation is disabled", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     type = form_data.get("type") | 
					
						
							|  |  |  |     prompt = form_data.get("prompt") | 
					
						
							|  |  |  |     messages = form_data.get("messages") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: | 
					
						
							|  |  |  |         if ( | 
					
						
							|  |  |  |             len(prompt) | 
					
						
							|  |  |  |             > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |                 detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"generating autocompletion using model {task_model_id} for user {user.email}" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": | 
					
						
							|  |  |  |         template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     content = autocomplete_generation_template( | 
					
						
							|  |  |  |         template, prompt, messages, type, {"name": user.name} | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             "task": str(TASKS.AUTOCOMPLETE_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-20 07:59:03 +08:00
										 |  |  |         log.error(f"Error generating chat completion: {e}") | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							| 
									
										
										
										
											2024-12-20 07:59:03 +08:00
										 |  |  |             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | 
					
						
							|  |  |  |             content={"detail": "An internal error has occurred."}, | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/emoji/completions") | 
					
						
							|  |  |  | async def generate_emoji( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data["model"] | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Check if the user has a custom task model | 
					
						
							|  |  |  |     # If the user has a custom task model, use that model | 
					
						
							|  |  |  |     task_model_id = get_task_model_id( | 
					
						
							|  |  |  |         model_id, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL, | 
					
						
							|  |  |  |         request.app.state.config.TASK_MODEL_EXTERNAL, | 
					
						
							|  |  |  |         models, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |     template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     content = emoji_generation_template( | 
					
						
							|  |  |  |         template, | 
					
						
							|  |  |  |         form_data["prompt"], | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "name": user.name, | 
					
						
							|  |  |  |             "location": user.info.get("location") if user.info else None, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							|  |  |  |         "model": task_model_id, | 
					
						
							|  |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": False, | 
					
						
							|  |  |  |         **( | 
					
						
							|  |  |  |             {"max_tokens": 4} | 
					
						
							| 
									
										
										
										
											2025-02-16 08:41:41 +08:00
										 |  |  |             if models[task_model_id].get("owned_by") == "ollama" | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             else { | 
					
						
							|  |  |  |                 "max_completion_tokens": 4, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |             "task": str(TASKS.EMOJI_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							| 
									
										
										
										
											2025-07-16 19:38:48 +08:00
										 |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         }, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             content={"detail": str(e)}, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.post("/moa/completions") | 
					
						
							|  |  |  | async def generate_moa_response( | 
					
						
							|  |  |  |     request: Request, form_data: dict, user=Depends(get_verified_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-13 15:26:47 +08:00
										 |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							| 
									
										
										
										
											2025-02-13 14:56:33 +08:00
										 |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     model_id = form_data["model"] | 
					
						
							| 
									
										
										
										
											2024-12-12 11:52:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_404_NOT_FOUND, | 
					
						
							|  |  |  |             detail="Model not found", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |     template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     content = moa_response_generation_template( | 
					
						
							|  |  |  |         template, | 
					
						
							|  |  |  |         form_data["prompt"], | 
					
						
							|  |  |  |         form_data["responses"], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     payload = { | 
					
						
							| 
									
										
										
										
											2025-04-03 10:33:20 +08:00
										 |  |  |         "model": model_id, | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |         "messages": [{"role": "user", "content": content}], | 
					
						
							|  |  |  |         "stream": form_data.get("stream", False), | 
					
						
							|  |  |  |         "metadata": { | 
					
						
							| 
									
										
										
										
											2025-02-13 17:57:02 +08:00
										 |  |  |             **(request.state.metadata if hasattr(request.state, "metadata") else {}), | 
					
						
							| 
									
										
										
										
											2025-01-03 06:32:25 +08:00
										 |  |  |             "chat_id": form_data.get("chat_id", None), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |             "task": str(TASKS.MOA_RESPONSE_GENERATION), | 
					
						
							|  |  |  |             "task_body": form_data, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-21 12:46:00 +08:00
										 |  |  |     # Process the payload through the pipeline | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         payload = await process_pipeline_inlet_filter(request, payload, user, models) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return await generate_chat_completion(request, form_data=payload, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-10 16:00:01 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  |         return JSONResponse( | 
					
						
							|  |  |  |             status_code=status.HTTP_400_BAD_REQUEST, | 
					
						
							|  |  |  |             content={"detail": str(e)}, | 
					
						
							|  |  |  |         ) |