refac
This commit is contained in:
		
							parent
							
								
									c7a9b5ccfa
								
							
						
					
					
						commit
						e5895af7a0
					
				
							
								
								
									
										128
									
								
								backend/main.py
								
								
								
								
							
							
						
						
									
										128
									
								
								backend/main.py
								
								
								
								
							|  | @ -212,6 +212,70 @@ origins = ["*"] | |||
| ################################## | ||||
| 
 | ||||
| 
 | ||||
| async def get_body_and_model_and_user(request): | ||||
|     # Read the original request body | ||||
|     body = await request.body() | ||||
|     body_str = body.decode("utf-8") | ||||
|     body = json.loads(body_str) if body_str else {} | ||||
| 
 | ||||
|     model_id = body["model"] | ||||
|     if model_id not in app.state.MODELS: | ||||
|         raise "Model not found" | ||||
|     model = app.state.MODELS[model_id] | ||||
| 
 | ||||
|     user = get_current_user( | ||||
|         request, | ||||
|         get_http_authorization_cred(request.headers.get("Authorization")), | ||||
|     ) | ||||
| 
 | ||||
|     return body, model, user | ||||
| 
 | ||||
| 
 | ||||
| def get_task_model_id(default_model_id): | ||||
|     # Set the task model | ||||
|     task_model_id = default_model_id | ||||
|     # Check if the user has a custom task model and use that model | ||||
|     if app.state.MODELS[task_model_id]["owned_by"] == "ollama": | ||||
|         if ( | ||||
|             app.state.config.TASK_MODEL | ||||
|             and app.state.config.TASK_MODEL in app.state.MODELS | ||||
|         ): | ||||
|             task_model_id = app.state.config.TASK_MODEL | ||||
|     else: | ||||
|         if ( | ||||
|             app.state.config.TASK_MODEL_EXTERNAL | ||||
|             and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS | ||||
|         ): | ||||
|             task_model_id = app.state.config.TASK_MODEL_EXTERNAL | ||||
| 
 | ||||
|     return task_model_id | ||||
| 
 | ||||
| 
 | ||||
| def get_filter_function_ids(model): | ||||
|     def get_priority(function_id): | ||||
|         function = Functions.get_function_by_id(function_id) | ||||
|         if function is not None and hasattr(function, "valves"): | ||||
|             return (function.valves if function.valves else {}).get("priority", 0) | ||||
|         return 0 | ||||
| 
 | ||||
|     filter_ids = [function.id for function in Functions.get_global_filter_functions()] | ||||
|     if "info" in model and "meta" in model["info"]: | ||||
|         filter_ids.extend(model["info"]["meta"].get("filterIds", [])) | ||||
|         filter_ids = list(set(filter_ids)) | ||||
| 
 | ||||
|     enabled_filter_ids = [ | ||||
|         function.id | ||||
|         for function in Functions.get_functions_by_type("filter", active_only=True) | ||||
|     ] | ||||
| 
 | ||||
|     filter_ids = [ | ||||
|         filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids | ||||
|     ] | ||||
| 
 | ||||
|     filter_ids.sort(key=get_priority) | ||||
|     return filter_ids | ||||
| 
 | ||||
| 
 | ||||
| async def get_function_call_response( | ||||
|     messages, files, tool_id, template, task_model_id, user, model | ||||
| ): | ||||
|  | @ -373,51 +437,6 @@ async def get_function_call_response( | |||
|     return None, None, False | ||||
| 
 | ||||
| 
 | ||||
| def get_task_model_id(default_model_id): | ||||
|     # Set the task model | ||||
|     task_model_id = default_model_id | ||||
|     # Check if the user has a custom task model and use that model | ||||
|     if app.state.MODELS[task_model_id]["owned_by"] == "ollama": | ||||
|         if ( | ||||
|             app.state.config.TASK_MODEL | ||||
|             and app.state.config.TASK_MODEL in app.state.MODELS | ||||
|         ): | ||||
|             task_model_id = app.state.config.TASK_MODEL | ||||
|     else: | ||||
|         if ( | ||||
|             app.state.config.TASK_MODEL_EXTERNAL | ||||
|             and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS | ||||
|         ): | ||||
|             task_model_id = app.state.config.TASK_MODEL_EXTERNAL | ||||
| 
 | ||||
|     return task_model_id | ||||
| 
 | ||||
| 
 | ||||
| def get_filter_function_ids(model): | ||||
|     def get_priority(function_id): | ||||
|         function = Functions.get_function_by_id(function_id) | ||||
|         if function is not None and hasattr(function, "valves"): | ||||
|             return (function.valves if function.valves else {}).get("priority", 0) | ||||
|         return 0 | ||||
| 
 | ||||
|     filter_ids = [function.id for function in Functions.get_global_filter_functions()] | ||||
|     if "info" in model and "meta" in model["info"]: | ||||
|         filter_ids.extend(model["info"]["meta"].get("filterIds", [])) | ||||
|         filter_ids = list(set(filter_ids)) | ||||
| 
 | ||||
|     enabled_filter_ids = [ | ||||
|         function.id | ||||
|         for function in Functions.get_functions_by_type("filter", active_only=True) | ||||
|     ] | ||||
| 
 | ||||
|     filter_ids = [ | ||||
|         filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids | ||||
|     ] | ||||
| 
 | ||||
|     filter_ids.sort(key=get_priority) | ||||
|     return filter_ids | ||||
| 
 | ||||
| 
 | ||||
| async def chat_completion_functions_handler(body, model, user): | ||||
|     skip_files = None | ||||
| 
 | ||||
|  | @ -579,25 +598,6 @@ async def chat_completion_files_handler(body): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| async def get_body_and_model_and_user(request): | ||||
|     # Read the original request body | ||||
|     body = await request.body() | ||||
|     body_str = body.decode("utf-8") | ||||
|     body = json.loads(body_str) if body_str else {} | ||||
| 
 | ||||
|     model_id = body["model"] | ||||
|     if model_id not in app.state.MODELS: | ||||
|         raise "Model not found" | ||||
|     model = app.state.MODELS[model_id] | ||||
| 
 | ||||
|     user = get_current_user( | ||||
|         request, | ||||
|         get_http_authorization_cred(request.headers.get("Authorization")), | ||||
|     ) | ||||
| 
 | ||||
|     return body, model, user | ||||
| 
 | ||||
| 
 | ||||
| class ChatCompletionMiddleware(BaseHTTPMiddleware): | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         if request.method == "POST" and any( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue