refac
This commit is contained in:
		
							parent
							
								
									c7a9b5ccfa
								
							
						
					
					
						commit
						e5895af7a0
					
				
							
								
								
									
										128
									
								
								backend/main.py
								
								
								
								
							
							
						
						
									
										128
									
								
								backend/main.py
								
								
								
								
							|  | @ -212,6 +212,70 @@ origins = ["*"] | ||||||
| ################################## | ################################## | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | async def get_body_and_model_and_user(request): | ||||||
|  |     # Read the original request body | ||||||
|  |     body = await request.body() | ||||||
|  |     body_str = body.decode("utf-8") | ||||||
|  |     body = json.loads(body_str) if body_str else {} | ||||||
|  | 
 | ||||||
|  |     model_id = body["model"] | ||||||
|  |     if model_id not in app.state.MODELS: | ||||||
|  |         raise "Model not found" | ||||||
|  |     model = app.state.MODELS[model_id] | ||||||
|  | 
 | ||||||
|  |     user = get_current_user( | ||||||
|  |         request, | ||||||
|  |         get_http_authorization_cred(request.headers.get("Authorization")), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     return body, model, user | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_task_model_id(default_model_id): | ||||||
|  |     # Set the task model | ||||||
|  |     task_model_id = default_model_id | ||||||
|  |     # Check if the user has a custom task model and use that model | ||||||
|  |     if app.state.MODELS[task_model_id]["owned_by"] == "ollama": | ||||||
|  |         if ( | ||||||
|  |             app.state.config.TASK_MODEL | ||||||
|  |             and app.state.config.TASK_MODEL in app.state.MODELS | ||||||
|  |         ): | ||||||
|  |             task_model_id = app.state.config.TASK_MODEL | ||||||
|  |     else: | ||||||
|  |         if ( | ||||||
|  |             app.state.config.TASK_MODEL_EXTERNAL | ||||||
|  |             and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS | ||||||
|  |         ): | ||||||
|  |             task_model_id = app.state.config.TASK_MODEL_EXTERNAL | ||||||
|  | 
 | ||||||
|  |     return task_model_id | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_filter_function_ids(model): | ||||||
|  |     def get_priority(function_id): | ||||||
|  |         function = Functions.get_function_by_id(function_id) | ||||||
|  |         if function is not None and hasattr(function, "valves"): | ||||||
|  |             return (function.valves if function.valves else {}).get("priority", 0) | ||||||
|  |         return 0 | ||||||
|  | 
 | ||||||
|  |     filter_ids = [function.id for function in Functions.get_global_filter_functions()] | ||||||
|  |     if "info" in model and "meta" in model["info"]: | ||||||
|  |         filter_ids.extend(model["info"]["meta"].get("filterIds", [])) | ||||||
|  |         filter_ids = list(set(filter_ids)) | ||||||
|  | 
 | ||||||
|  |     enabled_filter_ids = [ | ||||||
|  |         function.id | ||||||
|  |         for function in Functions.get_functions_by_type("filter", active_only=True) | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     filter_ids = [ | ||||||
|  |         filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     filter_ids.sort(key=get_priority) | ||||||
|  |     return filter_ids | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| async def get_function_call_response( | async def get_function_call_response( | ||||||
|     messages, files, tool_id, template, task_model_id, user, model |     messages, files, tool_id, template, task_model_id, user, model | ||||||
| ): | ): | ||||||
|  | @ -373,51 +437,6 @@ async def get_function_call_response( | ||||||
|     return None, None, False |     return None, None, False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_task_model_id(default_model_id): |  | ||||||
|     # Set the task model |  | ||||||
|     task_model_id = default_model_id |  | ||||||
|     # Check if the user has a custom task model and use that model |  | ||||||
|     if app.state.MODELS[task_model_id]["owned_by"] == "ollama": |  | ||||||
|         if ( |  | ||||||
|             app.state.config.TASK_MODEL |  | ||||||
|             and app.state.config.TASK_MODEL in app.state.MODELS |  | ||||||
|         ): |  | ||||||
|             task_model_id = app.state.config.TASK_MODEL |  | ||||||
|     else: |  | ||||||
|         if ( |  | ||||||
|             app.state.config.TASK_MODEL_EXTERNAL |  | ||||||
|             and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS |  | ||||||
|         ): |  | ||||||
|             task_model_id = app.state.config.TASK_MODEL_EXTERNAL |  | ||||||
| 
 |  | ||||||
|     return task_model_id |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_filter_function_ids(model): |  | ||||||
|     def get_priority(function_id): |  | ||||||
|         function = Functions.get_function_by_id(function_id) |  | ||||||
|         if function is not None and hasattr(function, "valves"): |  | ||||||
|             return (function.valves if function.valves else {}).get("priority", 0) |  | ||||||
|         return 0 |  | ||||||
| 
 |  | ||||||
|     filter_ids = [function.id for function in Functions.get_global_filter_functions()] |  | ||||||
|     if "info" in model and "meta" in model["info"]: |  | ||||||
|         filter_ids.extend(model["info"]["meta"].get("filterIds", [])) |  | ||||||
|         filter_ids = list(set(filter_ids)) |  | ||||||
| 
 |  | ||||||
|     enabled_filter_ids = [ |  | ||||||
|         function.id |  | ||||||
|         for function in Functions.get_functions_by_type("filter", active_only=True) |  | ||||||
|     ] |  | ||||||
| 
 |  | ||||||
|     filter_ids = [ |  | ||||||
|         filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids |  | ||||||
|     ] |  | ||||||
| 
 |  | ||||||
|     filter_ids.sort(key=get_priority) |  | ||||||
|     return filter_ids |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| async def chat_completion_functions_handler(body, model, user): | async def chat_completion_functions_handler(body, model, user): | ||||||
|     skip_files = None |     skip_files = None | ||||||
| 
 | 
 | ||||||
|  | @ -579,25 +598,6 @@ async def chat_completion_files_handler(body): | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def get_body_and_model_and_user(request): |  | ||||||
|     # Read the original request body |  | ||||||
|     body = await request.body() |  | ||||||
|     body_str = body.decode("utf-8") |  | ||||||
|     body = json.loads(body_str) if body_str else {} |  | ||||||
| 
 |  | ||||||
|     model_id = body["model"] |  | ||||||
|     if model_id not in app.state.MODELS: |  | ||||||
|         raise "Model not found" |  | ||||||
|     model = app.state.MODELS[model_id] |  | ||||||
| 
 |  | ||||||
|     user = get_current_user( |  | ||||||
|         request, |  | ||||||
|         get_http_authorization_cred(request.headers.get("Authorization")), |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
|     return body, model, user |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class ChatCompletionMiddleware(BaseHTTPMiddleware): | class ChatCompletionMiddleware(BaseHTTPMiddleware): | ||||||
|     async def dispatch(self, request: Request, call_next): |     async def dispatch(self, request: Request, call_next): | ||||||
|         if request.method == "POST" and any( |         if request.method == "POST" and any( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue