251 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			251 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
from typing import Optional
 | 
						|
import logging
 | 
						|
from urllib.parse import urlparse
 | 
						|
 | 
						|
from qdrant_client import QdrantClient as Qclient
 | 
						|
from qdrant_client.http.models import PointStruct
 | 
						|
from qdrant_client.models import models
 | 
						|
 | 
						|
from open_webui.retrieval.vector.main import (
 | 
						|
    VectorDBBase,
 | 
						|
    VectorItem,
 | 
						|
    SearchResult,
 | 
						|
    GetResult,
 | 
						|
)
 | 
						|
from open_webui.config import (
 | 
						|
    QDRANT_URI,
 | 
						|
    QDRANT_API_KEY,
 | 
						|
    QDRANT_ON_DISK,
 | 
						|
    QDRANT_GRPC_PORT,
 | 
						|
    QDRANT_PREFER_GRPC,
 | 
						|
    QDRANT_COLLECTION_PREFIX,
 | 
						|
    QDRANT_TIMEOUT,
 | 
						|
    QDRANT_HNSW_M,
 | 
						|
)
 | 
						|
from open_webui.env import SRC_LOG_LEVELS
 | 
						|
 | 
						|
# Stand-in for "no limit". Qdrant's query/scroll calls fall back to a limit
# of 10 when none is passed, so this large value is passed instead whenever
# a caller wants every matching point.
NO_LIMIT = 999_999_999

# Module logger, at the level configured under the "RAG" key of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
						|
 | 
						|
 | 
						|
class QdrantClient(VectorDBBase):
    """Qdrant implementation of the ``VectorDBBase`` interface.

    Every logical collection name is stored in Qdrant under the physical name
    ``f"{collection_prefix}_{collection_name}"``. Each point's payload is a
    dict of the form ``{"text": <document text>, "metadata": <metadata>}``,
    and collections are created with cosine distance.

    If ``QDRANT_URI`` is not configured, ``self.client`` is left as ``None``
    and any method that talks to Qdrant will fail when called.
    """

    def __init__(self):
        self.collection_prefix = QDRANT_COLLECTION_PREFIX
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        self.QDRANT_ON_DISK = QDRANT_ON_DISK
        self.PREFER_GRPC = QDRANT_PREFER_GRPC
        self.GRPC_PORT = QDRANT_GRPC_PORT
        self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
        self.QDRANT_HNSW_M = QDRANT_HNSW_M

        if not self.QDRANT_URI:
            self.client = None
            return

        # Unified handling for either scheme
        parsed = urlparse(self.QDRANT_URI)
        host = parsed.hostname or self.QDRANT_URI
        http_port = parsed.port or 6333  # default REST port

        if self.PREFER_GRPC:
            # In gRPC mode the client is built from host + ports rather than
            # the full URL, so the gRPC port can be configured separately.
            self.client = Qclient(
                host=host,
                port=http_port,
                grpc_port=self.GRPC_PORT,
                prefer_grpc=self.PREFER_GRPC,
                api_key=self.QDRANT_API_KEY,
                timeout=self.QDRANT_TIMEOUT,
            )
        else:
            # Use the instance attribute (not the module-level config value)
            # so both branches honour the same timeout setting.
            self.client = Qclient(
                url=self.QDRANT_URI,
                api_key=self.QDRANT_API_KEY,
                timeout=self.QDRANT_TIMEOUT,
            )

    def _prefixed_name(self, collection_name: str) -> str:
        """Return the physical Qdrant collection name for *collection_name*."""
        return f"{self.collection_prefix}_{collection_name}"

    @staticmethod
    def _metadata_conditions(metadata_filter: dict) -> list:
        """Build one exact-match condition on ``metadata.<key>`` per filter entry."""
        return [
            models.FieldCondition(
                key=f"metadata.{key}", match=models.MatchValue(value=value)
            )
            for key, value in metadata_filter.items()
        ]

    def _result_to_get_result(self, points) -> GetResult:
        """Convert Qdrant points into a single-query ``GetResult``.

        Each point must carry a payload with ``"text"`` and ``"metadata"``
        keys (as written by ``_create_points``). The ids, documents and
        metadatas lists are each wrapped in an outer list, i.e. one "query".
        """
        ids = []
        documents = []
        metadatas = []

        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        return GetResult(
            ids=[ids],
            documents=[documents],
            metadatas=[metadatas],
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create a cosine-distance collection and its payload indexes.

        Args:
            collection_name: Logical (unprefixed) collection name.
            dimension: Vector size for the collection.
        """
        collection_name_with_prefix = self._prefixed_name(collection_name)
        self.client.create_collection(
            collection_name=collection_name_with_prefix,
            vectors_config=models.VectorParams(
                size=dimension,
                distance=models.Distance.COSINE,
                on_disk=self.QDRANT_ON_DISK,
            ),
            hnsw_config=models.HnswConfigDiff(
                m=self.QDRANT_HNSW_M,
            ),
        )

        # Create payload indexes for efficient filtering
        for field_name in ("metadata.hash", "metadata.file_id"):
            self.client.create_payload_index(
                collection_name=collection_name_with_prefix,
                field_name=field_name,
                field_schema=models.KeywordIndexParams(
                    type=models.KeywordIndexType.KEYWORD,
                    is_tenant=False,
                    on_disk=self.QDRANT_ON_DISK,
                ),
            )
        log.info(f"collection {collection_name_with_prefix} successfully created!")

    def _create_collection_if_not_exists(self, collection_name, dimension):
        """Create the collection with the given vector size unless it already exists."""
        if not self.has_collection(collection_name=collection_name):
            self._create_collection(
                collection_name=collection_name, dimension=dimension
            )

    def _create_points(self, items: list[VectorItem]) -> list[PointStruct]:
        """Turn ``VectorItem`` dicts into Qdrant points (payload: text + metadata)."""
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={"text": item["text"], "metadata": item["metadata"]},
            )
            for item in items
        ]

    def has_collection(self, collection_name: str) -> bool:
        """Return whether the prefixed collection exists in Qdrant."""
        return self.client.collection_exists(self._prefixed_name(collection_name))

    def delete_collection(self, collection_name: str):
        """Drop the prefixed collection and return the client's result."""
        return self.client.delete_collection(
            collection_name=self._prefixed_name(collection_name)
        )

    def search(
        self,
        collection_name: str,
        vectors: list[list[float | int]],
        limit: Optional[int],
    ) -> Optional[SearchResult]:
        """Return the ``limit`` nearest neighbours of ``vectors[0]``.

        NOTE(review): only the first query vector is searched; any further
        vectors in ``vectors`` are ignored.

        Args:
            collection_name: Logical (unprefixed) collection name.
            vectors: Query vectors; only ``vectors[0]`` is used.
            limit: Maximum number of hits; ``None`` means "no limit".

        Returns:
            A single-query ``SearchResult`` whose distances are cosine scores
            rescaled from [-1, 1] to [0, 1].
        """
        if limit is None:
            limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

        query_response = self.client.query_points(
            collection_name=self._prefixed_name(collection_name),
            query=vectors[0],
            limit=limit,
        )
        points = query_response.points
        get_result = self._result_to_get_result(points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            # qdrant distance is [-1, 1], normalize to [0, 1]
            distances=[[(point.score + 1.0) / 2.0 for point in points]],
        )

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """Return points whose metadata matches ALL key/value pairs in ``filter``.

        Uses ``must`` (AND) semantics, the same as ``delete``. The previous
        ``should`` (OR) filter made multi-key queries match any single key.

        Returns:
            A ``GetResult``, or ``None`` if the collection does not exist or
            the query fails (the error is logged).
        """
        if not self.has_collection(collection_name):
            return None

        try:
            if limit is None:
                limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

            # scroll() returns (points, next_page_offset); only the points are used.
            points, _next_offset = self.client.scroll(
                collection_name=self._prefixed_name(collection_name),
                scroll_filter=models.Filter(must=self._metadata_conditions(filter)),
                limit=limit,
            )
            return self._result_to_get_result(points)
        except Exception as e:
            log.exception(f"Error querying a collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every point in the collection as a single-query ``GetResult``."""
        points, _next_offset = self.client.scroll(
            collection_name=self._prefixed_name(collection_name),
            limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
        )
        return self._result_to_get_result(points)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Upload ``items``, creating the collection first if it does not exist.

        The vector size of a new collection is taken from the first item. An
        empty ``items`` list is a no-op. Previously it raised ``IndexError``.
        """
        if not items:
            return
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        self.client.upload_points(self._prefixed_name(collection_name), points)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if needed.

        Returns the client's upsert result, or ``None`` for an empty ``items``
        list (nothing to write).
        """
        if not items:
            return None
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        return self.client.upsert(self._prefixed_name(collection_name), points)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete points by ``metadata.id`` or by a metadata filter.

        Args:
            collection_name: Logical (unprefixed) collection name.
            ids: If non-empty, delete every point whose ``metadata.id`` equals
                any of these values. Takes precedence over ``filter``.
            filter: Otherwise, delete points whose metadata matches ALL of the
                given key/value pairs.

        Returns:
            The client's delete result, or ``None`` when neither ``ids`` nor
            ``filter`` is given (nothing is deleted in that case).
        """
        if ids:
            # A single MatchAny condition gives "any of these ids" semantics.
            # One MatchValue per id under ``must`` would require a point to
            # equal every id at once, so multi-id deletes matched nothing.
            conditions = [
                models.FieldCondition(
                    key="metadata.id",
                    match=models.MatchAny(any=list(ids)),
                )
            ]
        elif filter:
            conditions = self._metadata_conditions(filter)
        else:
            # Filter(must=[]) matches every point, so deleting with it would
            # silently empty the collection. Refuse instead.
            log.warning(
                f"delete() on collection '{collection_name}' called without ids "
                f"or filter; nothing deleted"
            )
            return None

        return self.client.delete(
            collection_name=self._prefixed_name(collection_name),
            points_selector=models.FilterSelector(
                filter=models.Filter(must=conditions)
            ),
        )

    def reset(self):
        """Delete every collection that belongs to this instance's prefix."""
        # Match "<prefix>_" (the naming scheme used by every method) rather
        # than the bare prefix. Otherwise a collection of another deployment
        # whose prefix merely starts with ours would also be deleted.
        owned_prefix = f"{self.collection_prefix}_"
        for collection in self.client.get_collections().collections:
            if collection.name.startswith(owned_prefix):
                self.client.delete_collection(collection_name=collection.name)
 |