| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  | import time | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from aiocache import cached | 
					
						
							|  |  |  | from fastapi import Request | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.routers import openai, ollama | 
					
						
							|  |  |  | from open_webui.functions import get_function_models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.models.functions import Functions | 
					
						
							|  |  |  | from open_webui.models.models import Models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.utils.plugin import load_function_module_by_id | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  | from open_webui.utils.access_control import has_access | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.config import ( | 
					
						
							|  |  |  |     DEFAULT_ARENA_MODEL, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL | 
					
						
							| 
									
										
										
										
											2025-02-13 23:29:26 +08:00
										 |  |  | from open_webui.models.users import UserModel | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Configure root logging to write to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 00:12:46 +08:00
										 |  |  | async def get_all_base_models(request: Request, user: UserModel = None): | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |     function_models = [] | 
					
						
							|  |  |  |     openai_models = [] | 
					
						
							|  |  |  |     ollama_models = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.ENABLE_OPENAI_API: | 
					
						
							| 
									
										
										
										
											2025-02-13 23:29:26 +08:00
										 |  |  |         openai_models = await openai.get_all_models(request, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |         openai_models = openai_models["data"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if request.app.state.config.ENABLE_OLLAMA_API: | 
					
						
							| 
									
										
										
										
											2025-02-13 23:29:26 +08:00
										 |  |  |         ollama_models = await ollama.get_all_models(request, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |         ollama_models = [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "id": model["model"], | 
					
						
							|  |  |  |                 "name": model["name"], | 
					
						
							|  |  |  |                 "object": "model", | 
					
						
							|  |  |  |                 "created": int(time.time()), | 
					
						
							|  |  |  |                 "owned_by": "ollama", | 
					
						
							|  |  |  |                 "ollama": model, | 
					
						
							| 
									
										
										
										
											2025-03-12 04:37:30 +08:00
										 |  |  |                 "tags": model.get("tags", []), | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |             for model in ollama_models["models"] | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-16 14:44:47 +08:00
										 |  |  |     function_models = await get_function_models(request) | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |     models = function_models + openai_models + ollama_models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 00:12:46 +08:00
										 |  |  | async def get_all_models(request, user: UserModel = None): | 
					
						
							| 
									
										
										
										
											2025-02-13 23:29:26 +08:00
										 |  |  |     models = await get_all_base_models(request, user=user) | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # If there are no models, return an empty list | 
					
						
							|  |  |  |     if len(models) == 0: | 
					
						
							|  |  |  |         return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Add arena models | 
					
						
							|  |  |  |     if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS: | 
					
						
							|  |  |  |         arena_models = [] | 
					
						
							|  |  |  |         if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: | 
					
						
							|  |  |  |             arena_models = [ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "id": model["id"], | 
					
						
							|  |  |  |                     "name": model["name"], | 
					
						
							|  |  |  |                     "info": { | 
					
						
							|  |  |  |                         "meta": model["meta"], | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     "object": "model", | 
					
						
							|  |  |  |                     "created": int(time.time()), | 
					
						
							|  |  |  |                     "owned_by": "arena", | 
					
						
							|  |  |  |                     "arena": True, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 for model in request.app.state.config.EVALUATION_ARENA_MODELS | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # Add default arena model | 
					
						
							|  |  |  |             arena_models = [ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "id": DEFAULT_ARENA_MODEL["id"], | 
					
						
							|  |  |  |                     "name": DEFAULT_ARENA_MODEL["name"], | 
					
						
							|  |  |  |                     "info": { | 
					
						
							|  |  |  |                         "meta": DEFAULT_ARENA_MODEL["meta"], | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     "object": "model", | 
					
						
							|  |  |  |                     "created": int(time.time()), | 
					
						
							|  |  |  |                     "owned_by": "arena", | 
					
						
							|  |  |  |                     "arena": True, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         models = models + arena_models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     global_action_ids = [ | 
					
						
							|  |  |  |         function.id for function in Functions.get_global_action_functions() | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     enabled_action_ids = [ | 
					
						
							|  |  |  |         function.id | 
					
						
							|  |  |  |         for function in Functions.get_functions_by_type("action", active_only=True) | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     custom_models = Models.get_all_models() | 
					
						
							|  |  |  |     for custom_model in custom_models: | 
					
						
							|  |  |  |         if custom_model.base_model_id is None: | 
					
						
							|  |  |  |             for model in models: | 
					
						
							|  |  |  |                 if ( | 
					
						
							|  |  |  |                     custom_model.id == model["id"] | 
					
						
							|  |  |  |                     or custom_model.id == model["id"].split(":")[0] | 
					
						
							|  |  |  |                 ): | 
					
						
							|  |  |  |                     if custom_model.is_active: | 
					
						
							|  |  |  |                         model["name"] = custom_model.name | 
					
						
							|  |  |  |                         model["info"] = custom_model.model_dump() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         action_ids = [] | 
					
						
							|  |  |  |                         if "info" in model and "meta" in model["info"]: | 
					
						
							|  |  |  |                             action_ids.extend( | 
					
						
							|  |  |  |                                 model["info"]["meta"].get("actionIds", []) | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         model["action_ids"] = action_ids | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         models.remove(model) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         elif custom_model.is_active and ( | 
					
						
							|  |  |  |             custom_model.id not in [model["id"] for model in models] | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             owned_by = "openai" | 
					
						
							|  |  |  |             pipe = None | 
					
						
							|  |  |  |             action_ids = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for model in models: | 
					
						
							|  |  |  |                 if ( | 
					
						
							|  |  |  |                     custom_model.base_model_id == model["id"] | 
					
						
							|  |  |  |                     or custom_model.base_model_id == model["id"].split(":")[0] | 
					
						
							|  |  |  |                 ): | 
					
						
							| 
									
										
										
										
											2025-02-16 08:42:58 +08:00
										 |  |  |                     owned_by = model.get("owned_by", "unknown owner") | 
					
						
							| 
									
										
										
										
											2024-12-13 12:22:17 +08:00
										 |  |  |                     if "pipe" in model: | 
					
						
							|  |  |  |                         pipe = model["pipe"] | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if custom_model.meta: | 
					
						
							|  |  |  |                 meta = custom_model.meta.model_dump() | 
					
						
							|  |  |  |                 if "actionIds" in meta: | 
					
						
							|  |  |  |                     action_ids.extend(meta["actionIds"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             models.append( | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "id": f"{custom_model.id}", | 
					
						
							|  |  |  |                     "name": custom_model.name, | 
					
						
							|  |  |  |                     "object": "model", | 
					
						
							|  |  |  |                     "created": custom_model.created_at, | 
					
						
							|  |  |  |                     "owned_by": owned_by, | 
					
						
							|  |  |  |                     "info": custom_model.model_dump(), | 
					
						
							|  |  |  |                     "preset": True, | 
					
						
							|  |  |  |                     **({"pipe": pipe} if pipe is not None else {}), | 
					
						
							|  |  |  |                     "action_ids": action_ids, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Process action_ids to get the actions | 
					
						
							|  |  |  |     def get_action_items_from_module(function, module): | 
					
						
							|  |  |  |         actions = [] | 
					
						
							|  |  |  |         if hasattr(module, "actions"): | 
					
						
							|  |  |  |             actions = module.actions | 
					
						
							|  |  |  |             return [ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "id": f"{function.id}.{action['id']}", | 
					
						
							|  |  |  |                     "name": action.get("name", f"{function.name} ({action['id']})"), | 
					
						
							|  |  |  |                     "description": function.meta.description, | 
					
						
							|  |  |  |                     "icon_url": action.get( | 
					
						
							|  |  |  |                         "icon_url", function.meta.manifest.get("icon_url", None) | 
					
						
							|  |  |  |                     ), | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 for action in actions | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return [ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "id": function.id, | 
					
						
							|  |  |  |                     "name": function.name, | 
					
						
							|  |  |  |                     "description": function.meta.description, | 
					
						
							|  |  |  |                     "icon_url": function.meta.manifest.get("icon_url", None), | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_function_module_by_id(function_id): | 
					
						
							|  |  |  |         if function_id in request.app.state.FUNCTIONS: | 
					
						
							|  |  |  |             function_module = request.app.state.FUNCTIONS[function_id] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             function_module, _, _ = load_function_module_by_id(function_id) | 
					
						
							|  |  |  |             request.app.state.FUNCTIONS[function_id] = function_module | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for model in models: | 
					
						
							|  |  |  |         action_ids = [ | 
					
						
							|  |  |  |             action_id | 
					
						
							|  |  |  |             for action_id in list(set(model.pop("action_ids", []) + global_action_ids)) | 
					
						
							|  |  |  |             if action_id in enabled_action_ids | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model["actions"] = [] | 
					
						
							|  |  |  |         for action_id in action_ids: | 
					
						
							|  |  |  |             action_function = Functions.get_function_by_id(action_id) | 
					
						
							|  |  |  |             if action_function is None: | 
					
						
							|  |  |  |                 raise Exception(f"Action not found: {action_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             function_module = get_function_module_by_id(action_id) | 
					
						
							|  |  |  |             model["actions"].extend( | 
					
						
							|  |  |  |                 get_action_items_from_module(action_function, function_module) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |     log.debug(f"get_all_models() returned {len(models)} models") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     request.app.state.MODELS = {model["id"]: model for model in models} | 
					
						
							|  |  |  |     return models | 
					
						
							| 
									
										
										
										
											2024-12-13 14:28:42 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def check_model_access(user, model): | 
					
						
							|  |  |  |     if model.get("arena"): | 
					
						
							|  |  |  |         if not has_access( | 
					
						
							|  |  |  |             user.id, | 
					
						
							|  |  |  |             type="read", | 
					
						
							|  |  |  |             access_control=model.get("info", {}) | 
					
						
							|  |  |  |             .get("meta", {}) | 
					
						
							|  |  |  |             .get("access_control", {}), | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             raise Exception("Model not found") | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         model_info = Models.get_model_by_id(model.get("id")) | 
					
						
							|  |  |  |         if not model_info: | 
					
						
							|  |  |  |             raise Exception("Model not found") | 
					
						
							|  |  |  |         elif not ( | 
					
						
							|  |  |  |             user.id == model_info.user_id | 
					
						
							|  |  |  |             or has_access( | 
					
						
							|  |  |  |                 user.id, type="read", access_control=model_info.access_control | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             raise Exception("Model not found") |