import time
import logging
import sys

from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
import uuid
import asyncio

from fastapi import Request, status
from starlette.responses import Response, StreamingResponse, JSONResponse


from open_webui.models.users import UserModel

from open_webui.socket.main import (
    sio,
    get_event_call,
    get_event_emitter,
)
from open_webui.functions import generate_function_chat_completion

from open_webui.routers.openai import (
    generate_chat_completion as generate_openai_chat_completion,
)

from open_webui.routers.ollama import (
    generate_chat_completion as generate_ollama_chat_completion,
)

from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
    process_pipeline_outlet_filter,
)

from open_webui.models.functions import Functions
from open_webui.models.models import Models


from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.models import get_all_models, check_model_access
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL


logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_direct_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    models: dict,
):
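    """
    Handle a chat completion for a "direct" (client-side) connection.

    The request is forwarded to the client's session over Socket.IO as a
    "request:chat:completion" event. For streaming requests, a dedicated
    channel ("{user_id}:{session_id}:{request_id}") is registered as a
    Socket.IO handler and incoming events are relayed back to the caller as
    server-sent events until a 'done' event arrives; otherwise the result of
    the event call is returned directly (raising if it reports an error).
    """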
					
						
    log.info("generate_direct_chat_completion")

    metadata = form_data.pop("metadata", {})

    user_id = metadata.get("user_id")
    session_id = metadata.get("session_id")
    request_id = str(uuid.uuid4())  # Generate a unique request ID

    event_caller = get_event_call(metadata)

    channel = f"{user_id}:{session_id}:{request_id}"

    if form_data.get("stream"):
        q = asyncio.Queue()

        async def message_listener(sid, data):
            """
            Handle received socket messages and push them into the queue.
            """
            await q.put(data)

        # Register the listener
        sio.on(channel, message_listener)

        # Start processing the chat completion in the background
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        log.info(f"res: {res}")

        if res.get("status", False):
            # Define a generator to stream responses
            async def event_generator():
                nonlocal q
                try:
                    while True:
                        data = await q.get()  # Wait for new messages
                        if isinstance(data, dict):
                            if "done" in data and data["done"]:
                                break  # Stop streaming when 'done' is received

                            yield f"data: {json.dumps(data)}\n\n"
                        elif isinstance(data, str):
                            yield data
                except Exception as e:
                    log.debug(f"Error in event generator: {e}")
                    pass

            # Background task to clean up the channel listener once streaming ends
            async def background():
                try:
                    del sio.handlers["/"][channel]
                except Exception as e:
                    pass

            # Return the streaming response
            return StreamingResponse(
                event_generator(), media_type="text/event-stream", background=background
            )
        else:
            raise Exception(str(res))
    else:
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        if "error" in res and res["error"]:
            raise Exception(res["error"])

        return res


async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    bypass_filter: bool = False,
):
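    """
    Resolve the target model and dispatch the chat completion request.

    Direct (client-side) connections are forwarded over Socket.IO, arena
    models pick a random candidate model, pipe (function) models run their
    function module, Ollama models go through the Ollama router with
    payload/response conversion, and everything else goes to the
    OpenAI-compatible router. Model access is checked unless bypass_filter
    or BYPASS_MODEL_ACCESS_CONTROL is set.
    """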
					
						
    log.debug(f"generate_chat_completion: {form_data}")
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
        log.debug(f"direct connection to model: {models}")
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise Exception("Model not found")

    model = models[model_id]

    if getattr(request.state, "direct", False):
        return await generate_direct_chat_completion(
            request, form_data, user=user, models=models
        )
    else:
        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            try:
                check_model_access(user, model)
            except Exception as e:
                raise e

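        # Arena models carry a list of candidate model ids in their metadata;
        # one candidate is chosen at random (honoring an optional "exclude"
        # filter_mode) and the request is re-dispatched to it.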
					
						
        if model.get("owned_by") == "arena":
            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
            if model_ids and filter_mode == "exclude":
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
                ]

            selected_model_id = None
            if isinstance(model_ids, list) and model_ids:
                selected_model_id = random.choice(model_ids)
            else:
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena"
                ]
                selected_model_id = random.choice(model_ids)

            form_data["model"] = selected_model_id

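            # For streaming arena requests, prepend a "selected_model_id" SSE
            # event so the client knows which model was chosen, then relay the
            # underlying stream unchanged.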
					
						
            if form_data.get("stream") == True:

                async def stream_wrapper(stream):
                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                    async for chunk in stream:
                        yield chunk

                response = await generate_chat_completion(
                    request, form_data, user, bypass_filter=True
                )
                return StreamingResponse(
                    stream_wrapper(response.body_iterator),
                    media_type="text/event-stream",
                    background=response.background,
                )
            else:
                return {
                    **(
                        await generate_chat_completion(
                            request, form_data, user, bypass_filter=True
                        )
                    ),
                    "selected_model_id": selected_model_id,
                }

        if model.get("pipe"):
            # This does not require bypass_filter because this is the only route
            # that uses this function and it is already bypassing the filter
            return await generate_function_chat_completion(
                request, form_data, user=user, models=models
            )
        if model.get("owned_by") == "ollama":
            # Using /ollama/api/chat endpoint
            form_data = convert_payload_openai_to_ollama(form_data)
            response = await generate_ollama_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )
            if form_data.get("stream"):
                response.headers["content-type"] = "text/event-stream"
                return StreamingResponse(
                    convert_streaming_response_ollama_to_openai(response),
                    headers=dict(response.headers),
                    background=response.background,
                )
            else:
                return convert_response_ollama_to_openai(response)
        else:
            return await generate_openai_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )


chat_completion = generate_chat_completion
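# Convenience alias; callers use `await chat_completion(request, form_data, user)`
# and get back either a dict (non-streaming) or a StreamingResponse (streaming).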
					
						


async def chat_completed(request: Request, form_data: dict, user: Any):
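    """
    Run the "outlet" stage after a chat completion has finished: the response
    data is passed through the pipeline outlet filter and then through any
    outlet filter functions configured for the model. Returns the filtered
    data, or an Exception instance if a filter step fails.
    """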
					
						
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise Exception("Model not found")

    model = models[model_id]

    try:
        data = await process_pipeline_outlet_filter(request, data, user, models)
    except Exception as e:
        return Exception(f"Error: {e}")

    metadata = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
        "user_id": user.id,
    }

    extra_params = {
        "__event_emitter__": get_event_emitter(metadata),
        "__event_call__": get_event_call(metadata),
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    try:
        result, _ = await process_filter_functions(
            request=request,
            filter_ids=get_sorted_filter_ids(model),
            filter_type="outlet",
            form_data=data,
            extra_params=extra_params,
        )
        return result
    except Exception as e:
        return Exception(f"Error: {e}")


async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
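    """
    Execute a custom action function (optionally "action_id.sub_action_id")
    against the given chat data. The function module is loaded and cached on
    app state, its valves are populated, and its `action` entry point is
    called with the request body plus any extra parameters declared in its
    signature. Returns the action's result, or an Exception instance on error.
    """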
					
						
    if "." in action_id:
        action_id, sub_action_id = action_id.split(".")
    else:
        sub_action_id = None

    action = Functions.get_function_by_id(action_id)
    if not action:
        raise Exception(f"Action not found: {action_id}")

    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]

    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    __event_emitter__ = get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )
    __event_call__ = get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )

    if action_id in request.app.state.FUNCTIONS:
        function_module = request.app.state.FUNCTIONS[action_id]
    else:
        function_module, _, _ = load_function_module_by_id(action_id)
        request.app.state.FUNCTIONS[action_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action = function_module.action

            # Get the signature of the function
            sig = inspect.signature(action)
            params = {"body": data}

            # Extra parameters to be passed to the function
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
                "__request__": request,
            }

            # Add extra params contained in the function signature
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }

                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                action_id, user.id
                            )
                        )
                except Exception as e:
                    log.exception(f"Failed to get user values: {e}")

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action):
                data = await action(**params)
            else:
                data = action(**params)

        except Exception as e:
            return Exception(f"Error: {e}")

    return data