| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | import random | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from fastapi import Request | 
					
						
							|  |  |  | from open_webui.models.users import UserModel | 
					
						
							|  |  |  | from open_webui.models.models import Models | 
					
						
							|  |  |  | from open_webui.utils.models import check_model_access | 
					
						
							|  |  |  | from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.routers.openai import embeddings as openai_embeddings | 
					
						
							| 
									
										
										
										
											2025-06-05 04:37:31 +08:00
										 |  |  | from open_webui.routers.ollama import ( | 
					
						
							|  |  |  |     embeddings as ollama_embeddings, | 
					
						
							|  |  |  |     GenerateEmbeddingsForm, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-04 23:06:38 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama | 
					
						
							| 
									
										
										
										
											2025-06-05 00:24:27 +08:00
										 |  |  | from open_webui.utils.response import convert_embedding_response_ollama_to_openai | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | 
 | 
					
						
# Send all log output to stdout at the globally configured verbosity.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)

# Module-level logger; its level is controlled by the project's "MAIN"
# per-subsystem log-level setting rather than the global default.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def generate_embeddings( | 
					
						
							|  |  |  |     request: Request, | 
					
						
							|  |  |  |     form_data: dict, | 
					
						
							|  |  |  |     user: UserModel, | 
					
						
							|  |  |  |     bypass_filter: bool = False, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-06-05 04:37:31 +08:00
										 |  |  |     Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama). | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     Args: | 
					
						
							|  |  |  |         request (Request): The FastAPI request context. | 
					
						
							|  |  |  |         form_data (dict): The input data sent to the endpoint. | 
					
						
							|  |  |  |         user (UserModel): The authenticated user. | 
					
						
							|  |  |  |         bypass_filter (bool): If True, disables access filtering (default False). | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Returns: | 
					
						
							|  |  |  |         dict: The embeddings response, following OpenAI API compatibility. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     if BYPASS_MODEL_ACCESS_CONTROL: | 
					
						
							|  |  |  |         bypass_filter = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Attach extra metadata from request.state if present | 
					
						
							|  |  |  |     if hasattr(request.state, "metadata"): | 
					
						
							|  |  |  |         if "metadata" not in form_data: | 
					
						
							|  |  |  |             form_data["metadata"] = request.state.metadata | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             form_data["metadata"] = { | 
					
						
							|  |  |  |                 **form_data["metadata"], | 
					
						
							|  |  |  |                 **request.state.metadata, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # If "direct" flag present, use only that model | 
					
						
							|  |  |  |     if getattr(request.state, "direct", False) and hasattr(request.state, "model"): | 
					
						
							|  |  |  |         models = { | 
					
						
							|  |  |  |             request.state.model["id"]: request.state.model, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = request.app.state.MODELS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model_id = form_data.get("model") | 
					
						
							|  |  |  |     if model_id not in models: | 
					
						
							|  |  |  |         raise Exception("Model not found") | 
					
						
							|  |  |  |     model = models[model_id] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Access filtering | 
					
						
							|  |  |  |     if not getattr(request.state, "direct", False): | 
					
						
							|  |  |  |         if not bypass_filter and user.role == "user": | 
					
						
							|  |  |  |             check_model_access(user, model) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Ollama backend | 
					
						
							|  |  |  |     if model.get("owned_by") == "ollama": | 
					
						
							|  |  |  |         ollama_payload = convert_embedding_payload_openai_to_ollama(form_data) | 
					
						
							|  |  |  |         response = await ollama_embeddings( | 
					
						
							|  |  |  |             request=request, | 
					
						
							| 
									
										
										
										
											2025-06-05 04:37:31 +08:00
										 |  |  |             form_data=GenerateEmbeddingsForm(**ollama_payload), | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  |             user=user, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-06-05 00:24:27 +08:00
										 |  |  |         return convert_embedding_response_ollama_to_openai(response) | 
					
						
							| 
									
										
										
										
											2025-06-04 22:09:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Default: OpenAI or compatible backend | 
					
						
							|  |  |  |     return await openai_embeddings( | 
					
						
							|  |  |  |         request=request, | 
					
						
							|  |  |  |         form_data=form_data, | 
					
						
							|  |  |  |         user=user, | 
					
						
							| 
									
										
										
										
											2025-06-05 04:37:31 +08:00
										 |  |  |     ) |