91 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			91 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
import random
 | 
						|
import logging
 | 
						|
import sys
 | 
						|
 | 
						|
from fastapi import Request
 | 
						|
from open_webui.models.users import UserModel
 | 
						|
from open_webui.models.models import Models
 | 
						|
from open_webui.utils.models import check_model_access
 | 
						|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
 | 
						|
 | 
						|
from open_webui.routers.openai import embeddings as openai_embeddings
 | 
						|
from open_webui.routers.ollama import (
 | 
						|
    embeddings as ollama_embeddings,
 | 
						|
    GenerateEmbeddingsForm,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
 | 
						|
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
 | 
						|
 | 
						|
# Send root logging to stdout at the globally configured level, then give this
# module its own logger tuned to the "MAIN" source level.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)

log: logging.Logger = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
 | 
						|
 | 
						|
 | 
						|
async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).

    The target model is resolved from one of two sources:

    * a direct connection, when ``request.state.direct`` is truthy AND
      ``request.state.model`` is set; only that model is considered;
    * otherwise, the shared registry ``request.app.state.MODELS``.

    Models resolved from the shared registry are subject to
    ``check_model_access`` for users whose role is ``"user"``, unless filtering
    is bypassed.

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The input data sent to the endpoint. Mutated in place:
            if ``request.state.metadata`` exists it is merged into
            ``form_data["metadata"]`` (request-state keys win on conflict).
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default False).
            Forced to True when ``BYPASS_MODEL_ACCESS_CONTROL`` is enabled.

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.
            For Ollama-owned models the Ollama response is converted with
            ``convert_embedding_response_ollama_to_openai``. Otherwise the
            result of ``openai_embeddings`` is returned as-is.

    Raises:
        Exception: If ``form_data["model"]`` is missing or not a known model.
            Access-denied errors presumably propagate from
            ``check_model_access`` (not visible here; confirm).
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present. An absent or
    # explicit None "metadata" entry is replaced outright; a real mapping is
    # merged with request-state keys taking precedence.
    if hasattr(request.state, "metadata"):
        if form_data.get("metadata") is None:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # Compute the "direct" decision once so model resolution and access
    # filtering always agree. Previously a request with `direct` set but no
    # `model` fell back to the shared registry AND skipped access checks.
    is_direct = bool(
        getattr(request.state, "direct", False) and hasattr(request.state, "model")
    )

    if is_direct:
        # Direct connection: only the model attached to the request is eligible.
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering applies only to models taken from the shared registry.
    if not is_direct and not bypass_filter and user.role == "user":
        check_model_access(user, model)

    # Ollama backend: translate the OpenAI-style payload in, and the response out.
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=GenerateEmbeddingsForm(**ollama_payload),
            user=user,
        )
        return convert_embedding_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend.
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )
 |