124 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
		
		
			
		
	
	
			124 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
| 
								 | 
							
								import random
							 | 
						||
| 
								 | 
							
								import logging
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from fastapi import Request
							 | 
						||
| 
								 | 
							
								from open_webui.models.users import UserModel
							 | 
						||
| 
								 | 
							
								from open_webui.models.models import Models
							 | 
						||
| 
								 | 
							
								from open_webui.utils.models import check_model_access
							 | 
						||
| 
								 | 
							
								from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from open_webui.routers.openai import embeddings as openai_embeddings
							 | 
						||
| 
								 | 
							
								from open_webui.routers.ollama import embeddings as ollama_embeddings
							 | 
						||
| 
								 | 
							
								from open_webui.routers.pipelines import process_pipeline_inlet_filter
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
							 | 
						||
| 
								 | 
							
								from open_webui.utils.response import convert_response_ollama_to_openai
							 | 
						||
| 
								 | 
							
								
							 | 
						||
# Module-wide logging: the root handler writes to stdout at the global level,
# while this module's own logger is tuned to the "MAIN" source level.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
def _merge_request_metadata(request: Request, form_data: dict) -> None:
    """Fold ``request.state.metadata`` (when present) into ``form_data`` in place.

    If ``form_data`` has no ``"metadata"`` key, the request-state metadata object
    is stored as-is. Otherwise the two mappings are merged, with keys from
    ``request.state.metadata`` taking precedence. An existing ``None`` value is
    treated as an empty mapping instead of crashing the ``**`` merge.
    """
    if not hasattr(request.state, "metadata"):
        return
    if "metadata" not in form_data:
        form_data["metadata"] = request.state.metadata
    else:
        form_data["metadata"] = {
            **(form_data["metadata"] or {}),
            **request.state.metadata,
        }


def _select_arena_model_id(arena_model: dict, models: dict) -> str:
    """Randomly pick a concrete model id for an Arena meta-model.

    ``arena_model["info"]["meta"]["model_ids"]`` is used as an include list,
    unless ``filter_mode == "exclude"``. In that case it names the models to
    leave out of the set of non-arena models in ``models``. When this does not
    produce a non-empty list, every non-arena model in ``models`` is a candidate.

    Raises:
        IndexError: from ``random.choice`` when there are no candidates at all.
    """
    meta = arena_model.get("info", {}).get("meta", {})
    model_ids = meta.get("model_ids")
    filter_mode = meta.get("filter_mode")

    if model_ids and filter_mode == "exclude":
        model_ids = [
            m["id"]
            for m in models.values()
            if m.get("owned_by") != "arena" and m["id"] not in model_ids
        ]

    if not (isinstance(model_ids, list) and model_ids):
        # Fallback: any concrete (non-arena) model.
        model_ids = [m["id"] for m in models.values() if m.get("owned_by") != "arena"]

    return random.choice(model_ids)


def _convert_ollama_embedding_response(response, model_id=None):
    """Reshape an Ollama embeddings response into the OpenAI embeddings schema.

    Two Ollama shapes are recognised:
    - ``{"embeddings": [[...], ...]}`` (batch endpoint)
    - ``{"embedding": [...]}`` (single-vector endpoint)

    Both become ``{"object": "list", "data": [{"object": "embedding",
    "embedding": vec, "index": i}, ...], "model": ...}``. ``model`` falls back
    to ``model_id`` when the response does not carry one. A ``usage`` entry is
    added only when ``prompt_eval_count`` is present. Non-dict or unrecognised
    responses are returned unchanged.
    """
    if not isinstance(response, dict):
        return response
    if "embeddings" in response:
        vectors = response.get("embeddings") or []
    elif "embedding" in response:
        vectors = [response["embedding"]]
    else:
        return response

    converted = {
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": vector, "index": index}
            for index, vector in enumerate(vectors)
        ],
        "model": response.get("model", model_id),
    }
    prompt_tokens = response.get("prompt_eval_count")
    if prompt_tokens is not None:
        converted["usage"] = {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens,
        }
    return converted


async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The input data sent to the endpoint. It is mutated in
            place: request-state metadata is merged into ``form_data["metadata"]``.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default False).
            Forced to True when BYPASS_MODEL_ACCESS_CONTROL is set.

    Returns:
        dict: The embeddings response, following OpenAI API compatibility. For
        Arena models it also carries ``"selected_model_id"``.

    Raises:
        Exception: "Model not found" when ``form_data["model"]`` is not a known model.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    _merge_request_metadata(request, form_data)

    # In "direct" mode only the model attached to the request is eligible,
    # and access filtering is skipped.
    direct = getattr(request.state, "direct", False)
    if direct and hasattr(request.state, "model"):
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering (only plain users, only for non-direct requests)
    if not direct and not bypass_filter and user.role == "user":
        check_model_access(user, model)

    # Arena "meta-model": delegate to a randomly selected concrete model
    if model.get("owned_by") == "arena":
        selected_model_id = _select_arena_model_id(model, models)
        inner_form = {**form_data, "model": selected_model_id}
        # The arena itself was already access-checked above.
        response = await generate_embeddings(
            request, inner_form, user, bypass_filter=True
        )
        # Tag which concrete model was chosen
        if isinstance(response, dict):
            response = {**response, "selected_model_id": selected_model_id}
        return response

    # Pipeline/Function models: run the inlet filter over the request payload.
    # NOTE(review): process_pipeline_inlet_filter is a request-side filter that
    # returns the (possibly rewritten) payload, not an embeddings result.
    # Returning it directly used to echo the request back to the caller, so we
    # keep dispatching to a backend afterwards.
    if model.get("pipe"):
        form_data = await process_pipeline_inlet_filter(
            request, form_data, user, models
        )

    # Ollama backend
    if model.get("owned_by") == "ollama":
        # NOTE(review): the Ollama router's `embeddings` endpoint takes its
        # `GenerateEmbeddingsForm` body model rather than a plain dict.
        # Imported locally to keep the module's top-level imports unchanged.
        from open_webui.routers.ollama import GenerateEmbeddingsForm

        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=GenerateEmbeddingsForm(**ollama_payload),
            user=user,
        )
        # The chat-completion converter loses the vectors; use an
        # embeddings-specific reshaping instead.
        return _convert_ollama_embedding_response(response, model_id)

    # Default: OpenAI or compatible backend
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )