1276 lines
		
	
	
		
			42 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			1276 lines
		
	
	
		
			42 KiB
		
	
	
	
		
			Python
		
	
	
	
from fastapi import (
 | 
						|
    FastAPI,
 | 
						|
    Depends,
 | 
						|
    HTTPException,
 | 
						|
    status,
 | 
						|
    UploadFile,
 | 
						|
    File,
 | 
						|
    Form,
 | 
						|
)
 | 
						|
from fastapi.middleware.cors import CORSMiddleware
 | 
						|
import os, shutil, logging, re
 | 
						|
from datetime import datetime
 | 
						|
 | 
						|
from pathlib import Path
 | 
						|
from typing import List, Union, Sequence, Iterator, Any
 | 
						|
 | 
						|
from chromadb.utils.batch_utils import create_batches
 | 
						|
from langchain_core.documents import Document
 | 
						|
 | 
						|
from langchain_community.document_loaders import (
 | 
						|
    WebBaseLoader,
 | 
						|
    TextLoader,
 | 
						|
    PyPDFLoader,
 | 
						|
    CSVLoader,
 | 
						|
    BSHTMLLoader,
 | 
						|
    Docx2txtLoader,
 | 
						|
    UnstructuredEPubLoader,
 | 
						|
    UnstructuredWordDocumentLoader,
 | 
						|
    UnstructuredMarkdownLoader,
 | 
						|
    UnstructuredXMLLoader,
 | 
						|
    UnstructuredRSTLoader,
 | 
						|
    UnstructuredExcelLoader,
 | 
						|
    UnstructuredPowerPointLoader,
 | 
						|
    YoutubeLoader,
 | 
						|
    OutlookMessageLoader,
 | 
						|
)
 | 
						|
from langchain.text_splitter import RecursiveCharacterTextSplitter
 | 
						|
 | 
						|
import validators
 | 
						|
import urllib.parse
 | 
						|
import socket
 | 
						|
 | 
						|
 | 
						|
from pydantic import BaseModel
 | 
						|
from typing import Optional
 | 
						|
import mimetypes
 | 
						|
import uuid
 | 
						|
import json
 | 
						|
 | 
						|
import sentence_transformers
 | 
						|
 | 
						|
from apps.webui.models.documents import (
 | 
						|
    Documents,
 | 
						|
    DocumentForm,
 | 
						|
    DocumentResponse,
 | 
						|
)
 | 
						|
 | 
						|
from apps.rag.utils import (
 | 
						|
    get_model_path,
 | 
						|
    get_embedding_function,
 | 
						|
    query_doc,
 | 
						|
    query_doc_with_hybrid_search,
 | 
						|
    query_collection,
 | 
						|
    query_collection_with_hybrid_search,
 | 
						|
)
 | 
						|
 | 
						|
from apps.rag.search.brave import search_brave
 | 
						|
from apps.rag.search.google_pse import search_google_pse
 | 
						|
from apps.rag.search.main import SearchResult
 | 
						|
from apps.rag.search.searxng import search_searxng
 | 
						|
from apps.rag.search.serper import search_serper
 | 
						|
from apps.rag.search.serpstack import search_serpstack
 | 
						|
from apps.rag.search.serply import search_serply
 | 
						|
from apps.rag.search.duckduckgo import search_duckduckgo
 | 
						|
 | 
						|
from utils.misc import (
 | 
						|
    calculate_sha256,
 | 
						|
    calculate_sha256_string,
 | 
						|
    sanitize_filename,
 | 
						|
    extract_folders_after_data_docs,
 | 
						|
)
 | 
						|
from utils.utils import get_current_user, get_admin_user
 | 
						|
 | 
						|
from config import (
 | 
						|
    AppConfig,
 | 
						|
    ENV,
 | 
						|
    SRC_LOG_LEVELS,
 | 
						|
    UPLOAD_DIR,
 | 
						|
    DOCS_DIR,
 | 
						|
    RAG_TOP_K,
 | 
						|
    RAG_RELEVANCE_THRESHOLD,
 | 
						|
    RAG_EMBEDDING_ENGINE,
 | 
						|
    RAG_EMBEDDING_MODEL,
 | 
						|
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 | 
						|
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
 | 
						|
    ENABLE_RAG_HYBRID_SEARCH,
 | 
						|
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
 | 
						|
    RAG_RERANKING_MODEL,
 | 
						|
    PDF_EXTRACT_IMAGES,
 | 
						|
    RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
						|
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
						|
    RAG_OPENAI_API_BASE_URL,
 | 
						|
    RAG_OPENAI_API_KEY,
 | 
						|
    DEVICE_TYPE,
 | 
						|
    CHROMA_CLIENT,
 | 
						|
    CHUNK_SIZE,
 | 
						|
    CHUNK_OVERLAP,
 | 
						|
    RAG_TEMPLATE,
 | 
						|
    ENABLE_RAG_LOCAL_WEB_FETCH,
 | 
						|
    YOUTUBE_LOADER_LANGUAGE,
 | 
						|
    ENABLE_RAG_WEB_SEARCH,
 | 
						|
    RAG_WEB_SEARCH_ENGINE,
 | 
						|
    SEARXNG_QUERY_URL,
 | 
						|
    GOOGLE_PSE_API_KEY,
 | 
						|
    GOOGLE_PSE_ENGINE_ID,
 | 
						|
    BRAVE_SEARCH_API_KEY,
 | 
						|
    SERPSTACK_API_KEY,
 | 
						|
    SERPSTACK_HTTPS,
 | 
						|
    SERPER_API_KEY,
 | 
						|
    SERPLY_API_KEY,
 | 
						|
    RAG_WEB_SEARCH_RESULT_COUNT,
 | 
						|
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
 | 
						|
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
 | 
						|
)
 | 
						|
 | 
						|
from constants import ERROR_MESSAGES
 | 
						|
 | 
						|
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime configuration for the RAG sub-app. Each attribute is seeded from the
# value of the same meaning imported from `config`. The admin endpoints below
# overwrite these attributes at runtime. Assignment goes through AppConfig's
# normal attribute protocol, in exactly the order listed.
app.state.config = AppConfig()

for _config_key, _config_value in (
    ("TOP_K", RAG_TOP_K),
    ("RELEVANCE_THRESHOLD", RAG_RELEVANCE_THRESHOLD),
    ("ENABLE_RAG_HYBRID_SEARCH", ENABLE_RAG_HYBRID_SEARCH),
    (
        "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
        ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    ),
    ("CHUNK_SIZE", CHUNK_SIZE),
    ("CHUNK_OVERLAP", CHUNK_OVERLAP),
    ("RAG_EMBEDDING_ENGINE", RAG_EMBEDDING_ENGINE),
    ("RAG_EMBEDDING_MODEL", RAG_EMBEDDING_MODEL),
    ("RAG_EMBEDDING_OPENAI_BATCH_SIZE", RAG_EMBEDDING_OPENAI_BATCH_SIZE),
    ("RAG_RERANKING_MODEL", RAG_RERANKING_MODEL),
    ("RAG_TEMPLATE", RAG_TEMPLATE),
    ("OPENAI_API_BASE_URL", RAG_OPENAI_API_BASE_URL),
    ("OPENAI_API_KEY", RAG_OPENAI_API_KEY),
    ("PDF_EXTRACT_IMAGES", PDF_EXTRACT_IMAGES),
    ("YOUTUBE_LOADER_LANGUAGE", YOUTUBE_LOADER_LANGUAGE),
    ("ENABLE_RAG_WEB_SEARCH", ENABLE_RAG_WEB_SEARCH),
    ("RAG_WEB_SEARCH_ENGINE", RAG_WEB_SEARCH_ENGINE),
    ("SEARXNG_QUERY_URL", SEARXNG_QUERY_URL),
    ("GOOGLE_PSE_API_KEY", GOOGLE_PSE_API_KEY),
    ("GOOGLE_PSE_ENGINE_ID", GOOGLE_PSE_ENGINE_ID),
    ("BRAVE_SEARCH_API_KEY", BRAVE_SEARCH_API_KEY),
    ("SERPSTACK_API_KEY", SERPSTACK_API_KEY),
    ("SERPSTACK_HTTPS", SERPSTACK_HTTPS),
    ("SERPER_API_KEY", SERPER_API_KEY),
    ("SERPLY_API_KEY", SERPLY_API_KEY),
    ("RAG_WEB_SEARCH_RESULT_COUNT", RAG_WEB_SEARCH_RESULT_COUNT),
    ("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", RAG_WEB_SEARCH_CONCURRENT_REQUESTS),
):
    setattr(app.state.config, _config_key, _config_value)
del _config_key, _config_value

# Translation target for YouTube transcripts. It lives directly on app.state
# (not inside AppConfig) and starts unset.
app.state.YOUTUBE_LOADER_TRANSLATION = None
 | 
						|
 | 
						|
 | 
						|
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the local SentenceTransformer embedding model.

    A local model is loaded only when ``embedding_model`` is non-empty AND the
    configured RAG_EMBEDDING_ENGINE is the empty string. In every other case
    the local model is cleared.

    Args:
        embedding_model: Model identifier, resolved through get_model_path.
        update_model: Forwarded to get_model_path as its update flag.

    Side effects:
        Sets ``app.state.sentence_transformer_ef`` to the loaded model or None.
    """
    wants_local_model = (
        bool(embedding_model) and app.state.config.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )
 | 
						|
 | 
						|
 | 
						|
def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the CrossEncoder used as the hybrid-search reranker.

    Args:
        reranking_model: Model identifier, resolved through get_model_path.
            An empty value unloads the reranker.
        update_model: Forwarded to get_model_path as its update flag.

    Side effects:
        Sets ``app.state.sentence_transformer_rf`` to the loaded model or None.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )
 | 
						|
 | 
						|
 | 
						|
# Load the local embedding / reranking models once at import time. The
# *_AUTO_UPDATE settings are forwarded to get_model_path as the update flag.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable consumed by the query handlers. It is built from the
# current engine/model settings, the local model (if any) and the
# OpenAI-compatible credentials. POST /embedding/update rebuilds it.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
 | 
						|
 | 
						|
# Origins allowed to call this sub-app cross-site: all of them.
origins = ["*"]

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive. Confirm this is intended for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
 | 
						|
 | 
						|
 | 
						|
class CollectionNameForm(BaseModel):
    # Base request body: name of the target collection. It defaults to "test";
    # the store endpoints derive a name from the URL when it is empty.
    collection_name: Union[str, None] = "test"
 | 
						|
 | 
						|
 | 
						|
class UrlForm(CollectionNameForm):
    # Request body for URL-based ingestion (web page or YouTube video).
    # `url` is required and also used to derive a collection name when none is given.
    url: str
 | 
						|
 | 
						|
 | 
						|
class SearchForm(CollectionNameForm):
    # Request body carrying a free-text search query plus the inherited
    # target collection name.
    query: str
 | 
						|
 | 
						|
 | 
						|
@app.get("/")
async def get_status():
    """Report liveness together with the main chunking / model settings."""
    cfg = app.state.config
    status_payload = {"status": True}
    status_payload["chunk_size"] = cfg.CHUNK_SIZE
    status_payload["chunk_overlap"] = cfg.CHUNK_OVERLAP
    status_payload["template"] = cfg.RAG_TEMPLATE
    status_payload["embedding_engine"] = cfg.RAG_EMBEDDING_ENGINE
    status_payload["embedding_model"] = cfg.RAG_EMBEDDING_MODEL
    status_payload["reranking_model"] = cfg.RAG_RERANKING_MODEL
    status_payload["openai_batch_size"] = cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE
    return status_payload
 | 
						|
 | 
						|
 | 
						|
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI-compatible settings.

    Guarded by get_admin_user. The response includes the stored API key verbatim.
    """
    cfg = app.state.config
    openai_settings = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
        "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_settings,
    }
 | 
						|
 | 
						|
 | 
						|
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name (guarded by get_admin_user)."""
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}
 | 
						|
 | 
						|
 | 
						|
class OpenAIConfigForm(BaseModel):
    # Credentials for an OpenAI-compatible embedding endpoint.
    url: str
    key: str
    # Optional embedding batch size. The update endpoint turns a missing or
    # zero value into 1.
    batch_size: Union[int, None] = None
 | 
						|
 | 
						|
 | 
						|
class EmbeddingModelUpdateForm(BaseModel):
    # Only consulted when embedding_engine is "ollama" or "openai".
    openai_config: Union[OpenAIConfigForm, None] = None
    # An empty engine selects the local SentenceTransformer model.
    embedding_engine: str
    embedding_model: str
 | 
						|
 | 
						|
 | 
						|
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    For the "ollama" and "openai" engines, a supplied ``openai_config``
    replaces the stored base URL, API key and batch size. A missing or zero
    batch size becomes 1. The local SentenceTransformer is then reloaded (or
    cleared) and ``app.state.EMBEDDING_FUNCTION`` is rebuilt.

    Returns:
        The resulting embedding settings, including the stored API key.

    Raises:
        HTTPException: 500 with the error text if any step fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        cfg = app.state.config
        cfg.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        cfg.RAG_EMBEDDING_MODEL = form_data.embedding_model

        openai_form = form_data.openai_config
        engine_uses_openai_api = cfg.RAG_EMBEDDING_ENGINE in ("ollama", "openai")
        if engine_uses_openai_api and openai_form is not None:
            cfg.OPENAI_API_BASE_URL = openai_form.url
            cfg.OPENAI_API_KEY = openai_form.key
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE = openai_form.batch_size or 1

        update_embedding_model(cfg.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            cfg.RAG_EMBEDDING_ENGINE,
            cfg.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            cfg.OPENAI_API_KEY,
            cfg.OPENAI_API_BASE_URL,
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_settings = {
            "url": cfg.OPENAI_API_BASE_URL,
            "key": cfg.OPENAI_API_KEY,
            "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
            "embedding_model": cfg.RAG_EMBEDDING_MODEL,
            "openai_config": openai_settings,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
class RerankingModelUpdateForm(BaseModel):
    # New reranking model identifier. An empty string unloads the reranker.
    reranking_model: str
 | 
						|
 | 
						|
 | 
						|
@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the CrossEncoder reranker used by hybrid search.

    Stores the new model name, then reloads it through update_reranking_model
    with ``update_model=True``.

    Returns:
        ``{"status": True, "reranking_model": <new model name>}``.

    Raises:
        HTTPException: 500 with the error text if the model cannot be loaded.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # This was previously written as `update_reranking_model(...), True`,
        # which built and discarded a tuple, so the update flag never reached
        # the loader. Pass it as the real second argument.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return document-loader and web-search settings (guarded by get_admin_user).

    Search-provider API keys are included verbatim.
    """
    cfg = app.state.config

    search_section = {
        "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "serply_api_key": cfg.SERPLY_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }
    web_section = {
        "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "search": search_section,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": web_section,
    }
 | 
						|
 | 
						|
 | 
						|
class ChunkParamUpdateForm(BaseModel):
    # Text-splitter parameters. Both are required when the `chunk` section is sent.
    chunk_size: int
    chunk_overlap: int
 | 
						|
 | 
						|
 | 
						|
class YoutubeLoaderConfig(BaseModel):
    # Transcript languages handed to the YouTube loader.
    language: list[str]
    # Optional translation target. It is stored on app.state, not in AppConfig.
    translation: Union[str, None] = None
 | 
						|
 | 
						|
 | 
						|
class WebSearchConfig(BaseModel):
    # Whether web search is enabled at all.
    enabled: bool
    # Name of the search engine to use.
    engine: Union[str, None] = None
    # Per-provider connection settings / credentials.
    searxng_query_url: Union[str, None] = None
    google_pse_api_key: Union[str, None] = None
    google_pse_engine_id: Union[str, None] = None
    brave_search_api_key: Union[str, None] = None
    serpstack_api_key: Union[str, None] = None
    serpstack_https: Union[bool, None] = None
    serper_api_key: Union[str, None] = None
    serply_api_key: Union[str, None] = None
    # Result count and fetch concurrency for web search.
    result_count: Union[int, None] = None
    concurrent_requests: Union[int, None] = None
 | 
						|
 | 
						|
 | 
						|
class WebConfig(BaseModel):
    # Web-search settings. This section is required inside `web`.
    search: WebSearchConfig
    # Whether the web loader verifies TLS certificates.
    web_loader_ssl_verification: Union[bool, None] = None
 | 
						|
 | 
						|
 | 
						|
class ConfigUpdateForm(BaseModel):
    # Every section is optional. update_rag_config skips sections that are
    # null or omitted.
    pdf_extract_images: Union[bool, None] = None
    chunk: Union[ChunkParamUpdateForm, None] = None
    youtube: Union[YoutubeLoaderConfig, None] = None
    web: Union[WebConfig, None] = None
 | 
						|
 | 
						|
 | 
						|
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply admin changes to the loader and web-search settings.

    - ``pdf_extract_images``: changed only when provided.
    - ``chunk`` / ``youtube``: replaced as a whole when the section is present.
    - ``web``: when present, search fields are overwritten with the submitted
      values. The exceptions are ``web_loader_ssl_verification``,
      ``search.result_count`` and ``search.concurrent_requests``, which keep
      their current value when sent as null or omitted.

    Returns:
        The resulting settings, in the same shape as ``GET /config``.
    """
    app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        search = web.search

        # These three fields are Optional in the form. Previously an omitted
        # value overwrote the stored setting with None. The SSL flag is later
        # passed as `verify_ssl` to the web loader, where a falsy None
        # presumably disables certificate checks. The count/concurrency values
        # are presumably expected to be integers. Keep the current value instead.
        if web.web_loader_ssl_verification is not None:
            app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                web.web_loader_ssl_verification
            )

        app.state.config.ENABLE_RAG_WEB_SEARCH = search.enabled
        app.state.config.RAG_WEB_SEARCH_ENGINE = search.engine
        app.state.config.SEARXNG_QUERY_URL = search.searxng_query_url
        app.state.config.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        app.state.config.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        app.state.config.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        app.state.config.SERPSTACK_API_KEY = search.serpstack_api_key
        app.state.config.SERPSTACK_HTTPS = search.serpstack_https
        app.state.config.SERPER_API_KEY = search.serper_api_key
        app.state.config.SERPLY_API_KEY = search.serply_api_key

        if search.result_count is not None:
            app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        if search.concurrent_requests is not None:
            app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
                search.concurrent_requests
            )

    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "serply_api_key": app.state.config.SERPLY_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
 | 
						|
 | 
						|
 | 
						|
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the active RAG prompt template to any user accepted by get_current_user."""
    active_template = app.state.config.RAG_TEMPLATE
    return {"status": True, "template": active_template}
 | 
						|
 | 
						|
 | 
						|
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the retrieval defaults.

    The response holds the template, top-k (``k``), relevance threshold
    (``r``) and the hybrid-search flag.
    """
    cfg = app.state.config
    return dict(
        status=True,
        template=cfg.RAG_TEMPLATE,
        k=cfg.TOP_K,
        r=cfg.RELEVANCE_THRESHOLD,
        hybrid=cfg.ENABLE_RAG_HYBRID_SEARCH,
    )
 | 
						|
 | 
						|
 | 
						|
class QuerySettingsForm(BaseModel):
    # New retrieval defaults. update_query_settings treats any falsy value as
    # "reset to built-in default".
    k: Union[int, None] = None
    r: Union[float, None] = None
    template: Union[str, None] = None
    hybrid: Union[bool, None] = None
 | 
						|
 | 
						|
 | 
						|
@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the retrieval defaults.

    Any falsy field falls back to a built-in default: the imported
    RAG_TEMPLATE, k=4, r=0.0 and hybrid=False.

    NOTE(review): the template fallback assigns the imported RAG_TEMPLATE
    object itself. Whether that restores the original default depends on how
    AppConfig stores values; confirm.
    """
    cfg = app.state.config
    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return dict(
        status=True,
        template=cfg.RAG_TEMPLATE,
        k=cfg.TOP_K,
        r=cfg.RELEVANCE_THRESHOLD,
        hybrid=cfg.ENABLE_RAG_HYBRID_SEARCH,
    )
 | 
						|
 | 
						|
 | 
						|
class QueryDocForm(BaseModel):
    # Query against one collection. k / r override the configured TOP_K /
    # RELEVANCE_THRESHOLD.
    collection_name: str
    query: str
    k: Union[int, None] = None
    r: Union[float, None] = None
    hybrid: Union[bool, None] = None
 | 
						|
 | 
						|
 | 
						|
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Run a similarity query against a single collection.

    Uses query_doc_with_hybrid_search with the configured reranker when
    ENABLE_RAG_HYBRID_SEARCH is on, otherwise a plain query_doc.
    ``form_data.k`` and ``form_data.r`` override TOP_K / RELEVANCE_THRESHOLD.
    ``form_data.hybrid`` is accepted but not consulted; the global setting decides.

    Raises:
        HTTPException: 400 with the error text if the query fails.
    """
    try:
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                # k=0 is not a usable result count, so any falsy k falls back.
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                # Compare with None: r=0.0 is a legitimate "no threshold"
                # request and must not be replaced by the global threshold.
                r=(
                    form_data.r
                    if form_data.r is not None
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k if form_data.k else app.state.config.TOP_K,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
class QueryCollectionsForm(BaseModel):
    # Query across several collections. k / r override the configured
    # TOP_K / RELEVANCE_THRESHOLD.
    collection_names: list[str]
    query: str
    k: Union[int, None] = None
    r: Union[float, None] = None
    hybrid: Union[bool, None] = None
 | 
						|
 | 
						|
 | 
						|
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Run a similarity query across several collections.

    Uses query_collection_with_hybrid_search with the configured reranker when
    ENABLE_RAG_HYBRID_SEARCH is on, otherwise a plain query_collection.
    ``form_data.k`` and ``form_data.r`` override TOP_K / RELEVANCE_THRESHOLD.
    ``form_data.hybrid`` is accepted but not consulted; the global setting decides.

    Raises:
        HTTPException: 400 with the error text if the query fails.
    """
    try:
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                # k=0 is not a usable result count, so any falsy k falls back.
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                # Compare with None: r=0.0 is a legitimate "no threshold"
                # request and must not be replaced by the global threshold.
                r=(
                    form_data.r
                    if form_data.r is not None
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k if form_data.k else app.state.config.TOP_K,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video (with video info) and index it into a collection.

    The configured YOUTUBE_LOADER_LANGUAGE and translation target are passed to
    the loader. When no collection name is supplied (empty string or null),
    one is derived from the first 63 characters of the URL's SHA-256 digest.
    The 63-character cap is presumably a collection-name length limit of the
    vector store. The collection is overwritten.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <url>}``.

    Raises:
        HTTPException: 400 with the error text if loading or storing fails.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # `collection_name` is Optional, so treat an explicit null like "".
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content into a collection.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm".
    The URL is checked by get_web_loader, and SSL verification follows
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION. When no collection name is
    supplied (empty string or null), one is derived from the first 63
    characters of the URL's SHA-256 digest. The collection is overwritten.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <url>}``.

    Raises:
        HTTPException: 400 with the error text if validation, loading or
            storing fails.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # `collection_name` is Optional, so treat an explicit null like "".
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
def get_web_loader(
    url: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: Optional[int] = None,
):
    """Validate ``url`` and build a SafeWebBaseLoader for it.

    Args:
        url: A URL or a sequence of URLs. Each one is checked by validate_url.
        verify_ssl: Forwarded to the loader.
        requests_per_second: Rate passed to the loader. It defaults to the
            live ``app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS``, which
            the admin /config/update endpoint can change. The module-level
            import of the same name is the config-module object assigned at
            startup, presumably a wrapper rather than the current value, so
            it is no longer used here.

    Raises:
        ValueError: ERROR_MESSAGES.INVALID_URL if validation fails.
    """
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if requests_per_second is None:
        requests_per_second = app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS

    return SafeWebBaseLoader(
        url,
        verify_ssl=verify_ssl,
        requests_per_second=requests_per_second,
        continue_on_failure=True,
    )
 | 
						|
 | 
						|
 | 
						|
def validate_url(url: Union[str, Sequence[str]]):
 | 
						|
    if isinstance(url, str):
 | 
						|
        if isinstance(validators.url(url), validators.ValidationError):
 | 
						|
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
 | 
						|
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
 | 
						|
            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
 | 
						|
            parsed_url = urllib.parse.urlparse(url)
 | 
						|
            # Get IPv4 and IPv6 addresses
 | 
						|
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
 | 
						|
            # Check if any of the resolved addresses are private
 | 
						|
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
 | 
						|
            for ip in ipv4_addresses:
 | 
						|
                if validators.ipv4(ip, private=True):
 | 
						|
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
 | 
						|
            for ip in ipv6_addresses:
 | 
						|
                if validators.ipv6(ip, private=True):
 | 
						|
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
 | 
						|
        return True
 | 
						|
    elif isinstance(url, Sequence):
 | 
						|
        return all(validate_url(u) for u in url)
 | 
						|
    else:
 | 
						|
        return False
 | 
						|
 | 
						|
 | 
						|
def resolve_hostname(hostname):
    """Resolve *hostname* and split the resulting addresses by family.

    Args:
        hostname: Host name or literal address to resolve.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)``: lists of address strings in the
        order ``socket.getaddrinfo`` produced them. Entries are not
        de-duplicated, so one address may appear more than once (e.g. once per
        socket type).

    Raises:
        socket.gaierror: if the name cannot be resolved.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])
    return ipv4_addresses, ipv6_addresses
 | 
						|
 | 
						|
 | 
						|
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run *query* against the selected web search *engine*.

    Each engine requires the following ``app.state.config`` settings:

    - ``searxng``: SEARXNG_QUERY_URL
    - ``google_pse``: GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID
    - ``brave``: BRAVE_SEARCH_API_KEY
    - ``serpstack``: SERPSTACK_API_KEY (SERPSTACK_HTTPS is passed as ``https_enabled``)
    - ``serper``: SERPER_API_KEY
    - ``serply``: SERPLY_API_KEY
    - ``duckduckgo``: no key required

    Every engine is asked for RAG_WEB_SEARCH_RESULT_COUNT results.

    Args:
        engine: Engine identifier, one of the names above.
        query: The text to search for.

    Returns:
        The list of SearchResult objects returned by the engine helper.

    Raises:
        Exception: when the selected engine's required setting is empty, or when
            *engine* is not one of the supported names.
    """
    # TODO: add playwright to search the web
    config = app.state.config

    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "duckduckgo":
        return search_duckduckgo(query, config.RAG_WEB_SEARCH_RESULT_COUNT)

    raise Exception("No search engine API key found in environment variables")
 | 
						|
 | 
						|
 | 
						|
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, fetch the result pages and index them in the vector DB.

    The engine is app.state.config.RAG_WEB_SEARCH_ENGINE. When
    ``form_data.collection_name`` is "", the collection name becomes the first
    63 characters of the query's SHA-256. The collection is written with
    ``overwrite=True``.

    Args:
        form_data: Request body carrying ``query`` and ``collection_name``.
        user: The authenticated user (dependency only; not otherwise used here).

    Returns:
        ``{"status": True, "collection_name": ..., "filenames": [result URLs]}``.

    Raises:
        HTTPException: 400 with ERROR_MESSAGES.WEB_SEARCH_ERROR if the search
            itself fails, or with ERROR_MESSAGES.DEFAULT if fetching or indexing
            the result pages fails.
    """
    try:
        # Use the module logger like the rest of this file, not the root logger.
        log.info(
            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        # log.exception already records the error; no separate print() needed.
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # Honour the same SSL-verification setting as /web. Previously search
        # results were always fetched with the loader default, whatever the config.
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a vector DB collection.

    Chunking uses app.state.config.CHUNK_SIZE / CHUNK_OVERLAP, with
    ``add_start_index=True``.

    Args:
        data: Documents produced by a loader (``split_documents`` input).
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: ERROR_MESSAGES.EMPTY_CONTENT when splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)
    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bool itself. A ``(result, None)`` tuple is always truthy, so
    # callers checking ``if result:`` could never detect a failed write.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
 | 
						|
 | 
						|
 | 
						|
def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw string and store the chunks in a vector DB collection.

    Every chunk receives (a copy of) *metadata*. Chunk size and overlap come from
    app.state.config, with ``add_start_index=True``.

    Args:
        text: The text to split.
        metadata: Metadata dict attached to each chunk.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        Whatever ``store_docs_in_vector_db`` returns.
    """
    config = app.state.config
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=config.CHUNK_SIZE,
        chunk_overlap=config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(
        docs=chunks, collection_name=collection_name, overwrite=overwrite
    )
 | 
						|
 | 
						|
 | 
						|
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed *docs* and add them to the Chroma collection *collection_name*.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Name of the Chroma collection to create and fill.
        overwrite: When True, an existing collection of the same name is
            deleted first.

    Returns:
        True when the documents were added, or when the failure was a
        ``UniqueConstraintError`` (presumably the collection already existed).
        False for any other error, which is logged.

    Side effects:
        ``datetime`` values inside each document's ``metadata`` dict are replaced
        in place with their ``str()`` form.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = []
    metadatas = []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    # ChromaDB does not accept datetime metadata values, so stringify them.
    for metadata in metadatas:
        datetime_keys = [k for k, v in metadata.items() if isinstance(v, datetime)]
        for key in datetime_keys:
            metadata[key] = str(metadata[key])

    try:
        if overwrite:
            matching = [
                existing
                for existing in CHROMA_CLIENT.list_collections()
                if collection_name == existing.name
            ]
            for _ in matching:
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        target_collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened only in the embedding input; stored documents keep them.
        embeddings = embed([text.replace("\n", " ") for text in texts])

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            target_collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return type(e).__name__ == "UniqueConstraintError"
 | 
						|
 | 
						|
 | 
						|
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for a file.

    The extension is the text after the last "." in *filename*, lower-cased. A
    name without a dot is used whole, so "Dockerfile" matches "dockerfile".
    Checks run in the fixed order below; anything unmatched is read as plain
    text.

    Args:
        filename: File name used to derive the extension.
        file_content_type: MIME type of the file; may be None.
        file_path: Path handed to the selected loader.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only for the plain-text
        fallback taken when neither the extension nor the MIME type matched.
    """
    ext = filename.split(".")[-1].lower()

    # Extensions read with TextLoader as source code, config or plain text.
    source_code_exts = {
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
        "bash", "swift", "vue", "svelte", "msg",
    }
    word_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_mimes = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    if ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if ext == "csv":
        return CSVLoader(file_path), True
    if ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == word_mime or ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_mimes or ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True
    if ext == "msg":
        return OutlookMessageLoader(file_path), True

    # Everything else is read as text; it only counts as a known type when the
    # extension is a recognised source extension or the MIME type contains "text/".
    is_text_mime = bool(file_content_type) and "text/" in file_content_type
    known_type = ext in source_code_exts or is_text_mime
    return TextLoader(file_path, autodetect_encoding=True), known_type
 | 
						|
 | 
						|
 | 
						|
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file to UPLOAD_DIR and index its contents in the vector DB.

    Args:
        collection_name: Target collection. When omitted or empty, the first 63
            characters of the file's SHA-256 digest are used.
        file: The uploaded file.
        user: The authenticated user (dependency only; not otherwise used here).

    Returns:
        ``{"status": True, "collection_name", "filename", "known_type"}``.

    Raises:
        HTTPException: 500 when the chunks could not be written to the vector DB.
            400 for any other failure, with ERROR_MESSAGES.PANDOC_NOT_INSTALLED
            when the error mentions a missing pandoc.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # basename() drops client-supplied directory components so the upload
        # cannot be written outside UPLOAD_DIR.
        filename = os.path.basename(unsanitized_filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        # An empty name would be rejected by Chroma anyway, so treat it like None.
        if not collection_name:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already carries the intended status and detail; don't re-wrap as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
 | 
						|
 | 
						|
 | 
						|
class TextRAGForm(BaseModel):
    """Request body for storing raw text in the vector DB."""

    # Display name; the /text endpoint stores it in chunk metadata as "name".
    name: str
    # The raw text to split and embed.
    content: str
    # Target collection; when None, the endpoint derives one from a hash of the content.
    collection_name: Union[str, None] = None
 | 
						|
 | 
						|
 | 
						|
@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Split and embed raw text into a vector DB collection.

    Args:
        form_data: ``name`` and ``content`` of the text, plus an optional
            ``collection_name``.
        user: The authenticated user; ``user.id`` is stored as ``created_by``.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 with ERROR_MESSAGES.DEFAULT() when storing fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Chroma collection names are limited to 63 characters and a SHA-256 hex
        # digest is 64. Truncate like the other endpoints in this module do.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return {"status": True, "collection_name": collection_name}
 | 
						|
 | 
						|
 | 
						|
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Admin-only: index every non-hidden file under DOCS_DIR and register it.

    For each file, the collection name is the first 63 characters of its
    SHA-256. After a successful store, a Documents row is created if none
    exists under the sanitized file name. Folder names returned by
    ``extract_folders_after_data_docs`` become the document's tags.

    This is best effort: any error for a single file is logged and the scan
    moves on to the next file.

    Returns:
        True once every file has been attempted.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is None:
                content = (
                    json.dumps({"tags": [{"name": name} for name in tags]})
                    if len(tags)
                    else "{}"
                )
                Documents.insert_new_doc(
                    user.id,
                    DocumentForm(
                        **{
                            "name": sanitized_filename,
                            "title": filename,
                            "collection_name": collection_name,
                            "filename": filename,
                            "content": content,
                        }
                    ),
                )
        except Exception as e:
            log.exception(e)

    return True
 | 
						|
 | 
						|
 | 
						|
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Admin-only: call ``reset()`` on the Chroma client.

    Per Chroma's API, this presumably discards all stored collections. Any
    exception raised by the reset propagates to the caller. Returns None.
    """
    CHROMA_CLIENT.reset()
    return None
 | 
						|
 | 
						|
 | 
						|
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Admin-only: delete every file, symlink and sub-directory inside UPLOAD_DIR.

    This is best effort: a missing directory or a per-entry deletion failure is
    logged, and the endpoint still returns True.

    Returns:
        True.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        if not os.path.exists(folder):
            # Report through the module logger (as ``reset`` does) rather than print().
            log.warning(f"The directory {folder} does not exist")
            return True

        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove the file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove the directory
            except Exception as e:
                log.error(f"Failed to delete {file_path}. Reason: {e}")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True
 | 
						|
 | 
						|
 | 
						|
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Admin-only: empty UPLOAD_DIR and reset the Chroma vector DB.

    This is best effort: per-entry deletion failures and a failing Chroma
    reset are logged, not raised. A missing UPLOAD_DIR no longer aborts the
    request before the vector DB reset runs.

    Returns:
        True.
    """
    folder = f"{UPLOAD_DIR}"
    if os.path.isdir(folder):
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                log.error("Failed to delete %s. Reason: %s" % (file_path, e))
    else:
        # Without this guard os.listdir raises and CHROMA_CLIENT.reset() is skipped.
        log.warning(f"The directory {folder} does not exist")

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
 | 
						|
 | 
						|
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader variant that logs and skips URLs that fail to load."""

    def lazy_load(self) -> Iterator[Document]:
        """Yield a Document for each URL in ``self.web_paths``.

        The document text is the parsed page's ``get_text`` output. Metadata
        always includes ``source`` (the URL). When the page provides them, it
        also includes:

        - ``title``: text of <title>.
        - ``description``: the ``content`` of <meta name="description">, or
          "No description found." when that attribute is missing.
        - ``language``: the <html> ``lang`` attribute, or "No language found.".

        Any exception while scraping or parsing a URL is logged and that URL is
        skipped, so one bad page does not abort the whole load.
        """
        for url in self.web_paths:
            try:
                page = self._scrape(url, bs_kwargs=self.bs_kwargs)
                page_text = page.get_text(**self.bs_get_text_kwargs)

                page_metadata = {"source": url}

                title_tag = page.find("title")
                if title_tag:
                    page_metadata["title"] = title_tag.get_text()

                description_tag = page.find("meta", attrs={"name": "description"})
                if description_tag:
                    page_metadata["description"] = description_tag.get(
                        "content", "No description found."
                    )

                html_tag = page.find("html")
                if html_tag:
                    page_metadata["language"] = html_tag.get(
                        "lang", "No language found."
                    )

                yield Document(page_content=page_text, metadata=page_metadata)
            except Exception as e:
                # Log and move on to the next URL.
                log.error(f"Error loading {url}: {e}")
 | 
						|
                
 | 
						|
if ENV == "dev":
 | 
						|
 | 
						|
    @app.get("/ef")
 | 
						|
    async def get_embeddings():
 | 
						|
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
 | 
						|
 | 
						|
    @app.get("/ef/{text}")
 | 
						|
    async def get_embeddings_text(text: str):
 | 
						|
        return {"result": app.state.EMBEDDING_FUNCTION(text)}
 |