qdrant client improvements

This commit is contained in:
expruc 2025-08-09 21:12:30 +03:00
parent f4d2c6027a
commit 88abd01b87
3 changed files with 28 additions and 7 deletions

View File

@ -1924,6 +1924,8 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true" QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true" QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334")) QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5"))
QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16"))
ENABLE_QDRANT_MULTITENANCY_MODE = ( ENABLE_QDRANT_MULTITENANCY_MODE = (
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true" os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
) )

View File

@ -19,6 +19,8 @@ from open_webui.config import (
QDRANT_GRPC_PORT, QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_COLLECTION_PREFIX, QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -36,6 +38,8 @@ class QdrantClient(VectorDBBase):
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI: if not self.QDRANT_URI:
self.client = None self.client = None
@ -53,9 +57,10 @@ class QdrantClient(VectorDBBase):
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
) )
else: else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,)
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids = [] ids = []
@ -85,6 +90,9 @@ class QdrantClient(VectorDBBase):
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK, on_disk=self.QDRANT_ON_DISK,
), ),
hnsw_config=models.HnswConfigDiff(
m=self.QDRANT_HNSW_M,
),
) )
# Create payload indexes for efficient filtering # Create payload indexes for efficient filtering
@ -183,11 +191,11 @@ class QdrantClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
points = self.client.query_points( points = self.client.scroll(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10! limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points[0])
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.

View File

@ -10,6 +10,8 @@ from open_webui.config import (
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_URI, QDRANT_URI,
QDRANT_COLLECTION_PREFIX, QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
@ -51,6 +53,8 @@ class QdrantClient(VectorDBBase):
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI: if not self.QDRANT_URI:
raise ValueError( raise ValueError(
@ -69,9 +73,10 @@ class QdrantClient(VectorDBBase):
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
) )
if self.PREFER_GRPC if self.PREFER_GRPC
else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,)
) )
# Main collection types for multi-tenancy # Main collection types for multi-tenancy
@ -133,6 +138,12 @@ class QdrantClient(VectorDBBase):
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK, on_disk=self.QDRANT_ON_DISK,
), ),
# Disable global index building due to multitenancy
# For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
hnsw_config=models.HnswConfigDiff(
payload_m=self.QDRANT_HNSW_M,
m=0,
),
) )
log.info( log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
@ -296,12 +307,12 @@ class QdrantClient(VectorDBBase):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None") log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None return None
tenant_filter = _tenant_filter(tenant_id) tenant_filter = _tenant_filter(tenant_id)
points = self.client.query_points( points = self.client.scroll(
collection_name=mt_collection, collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]), scroll_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT, limit=NO_LIMIT,
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points[0])
def upsert(self, collection_name: str, items: List[VectorItem]): def upsert(self, collection_name: str, items: List[VectorItem]):
""" """