283 lines
9.9 KiB
Python
283 lines
9.9 KiB
Python
import logging
|
|
from typing import Optional, Tuple, List, Dict, Any
|
|
|
|
from open_webui.config import (
|
|
MILVUS_URI,
|
|
MILVUS_TOKEN,
|
|
MILVUS_DB,
|
|
MILVUS_COLLECTION_PREFIX,
|
|
MILVUS_INDEX_TYPE,
|
|
MILVUS_METRIC_TYPE,
|
|
MILVUS_HNSW_M,
|
|
MILVUS_HNSW_EFCONSTRUCTION,
|
|
MILVUS_IVF_FLAT_NLIST,
|
|
)
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
from open_webui.retrieval.vector.main import (
|
|
GetResult,
|
|
SearchResult,
|
|
VectorDBBase,
|
|
VectorItem,
|
|
)
|
|
from pymilvus import (
|
|
connections,
|
|
utility,
|
|
Collection,
|
|
CollectionSchema,
|
|
FieldSchema,
|
|
DataType,
|
|
)
|
|
|
|
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# Scalar field that stores, for every row of a shared collection, the
# original (per-resource) collection name the row belongs to.
RESOURCE_ID_FIELD = "resource_id"
|
|
|
|
|
|
class MilvusClient(VectorDBBase):
    """Multi-tenant Milvus backend.

    Instead of one Milvus collection per logical collection, rows are stored
    in a small, fixed set of shared collections (memories, knowledge, files,
    web search, hash based). Each row carries the original collection name in
    the ``RESOURCE_ID_FIELD`` column, and every read/delete is scoped with a
    ``resource_id == '<name>'`` filter expression.
    """

    def __init__(self):
        """Connect to Milvus and compute the names of the shared collections."""
        # Milvus collection names can only contain numbers, letters, and underscores.
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
        connections.connect(
            alias="default",
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
        # Everything reset() is allowed to drop.
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    # ------------------------------------------------------------------
    # Expression helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _quote(value: Any) -> str:
        """Render ``value`` as a single-quoted string literal for a filter expression.

        Backslashes and single quotes are escaped so that a value containing
        them cannot terminate the literal early or change the expression.
        Values without those characters render exactly as ``'<value>'``.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    @classmethod
    def _resource_clause(cls, resource_id: str) -> str:
        """Return the clause restricting an expression to one logical collection."""
        return f"{RESOURCE_ID_FIELD} == {cls._quote(resource_id)}"

    @classmethod
    def _metadata_clauses(cls, filter: Optional[Dict[str, Any]]) -> List[str]:
        """Translate a ``{key: value}`` filter into ``metadata[...] == ...`` clauses.

        String values are emitted as quoted literals; any other value is
        interpolated with its ``str()`` form, unquoted. An empty or ``None``
        filter yields no clauses.
        """
        clauses = []
        for key, value in (filter or {}).items():
            rendered = cls._quote(value) if isinstance(value, str) else value
            clauses.append(f"metadata[{cls._quote(key)}] == {rendered}")
        return clauses

    # ------------------------------------------------------------------
    # Collection routing / creation
    # ------------------------------------------------------------------

    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        The resource ID is always the unmodified ``collection_name``; only the
        shared collection is chosen from the name's shape:

        * ``user-memory-*``            -> memory collection
        * ``file-*``                   -> file collection
        * ``web-search-*``             -> web search collection
        * 63 lowercase hex characters  -> hash based collection
          (presumably hash-derived names truncated to 63 chars; verify
          against the code that generates them)
        * anything else                -> knowledge collection

        WARNING: This mapping relies on current Open WebUI naming conventions for
        collection names. If Open WebUI changes how it generates collection names
        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
        formats), this mapping will break and route data to incorrect collections.
        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
        DATA MAPPING INSIDE THE DATABASE.
        """
        resource_id = collection_name

        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, resource_id
        elif collection_name.startswith("file-"):
            return self.FILE_COLLECTION, resource_id
        elif collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, resource_id
        elif len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, resource_id
        else:
            return self.KNOWLEDGE_COLLECTION, resource_id

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        """Create a shared collection, index it, and return the ``Collection``.

        The schema has a VARCHAR primary key ``id`` (max 36 chars, not
        auto-generated), a FLOAT_VECTOR ``vector`` of ``dimension`` components,
        a VARCHAR ``text`` (max 65535), a JSON ``metadata`` field and the
        VARCHAR resource-id field (max 255). A vector index configured from
        the ``MILVUS_*`` settings and an index on the resource-id field are
        created.
        """
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
            FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
        ]
        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
        collection = Collection(mt_collection_name, schema)

        # Only HNSW and IVF_FLAT get build parameters; any other index type
        # is created with an empty "params" dict.
        index_params = {
            "metric_type": MILVUS_METRIC_TYPE,
            "index_type": MILVUS_INDEX_TYPE,
            "params": {},
        }
        if MILVUS_INDEX_TYPE == "HNSW":
            index_params["params"] = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
            index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}

        collection.create_index("vector", index_params)
        # Every query filters on the resource id, so index that field as well.
        collection.create_index(RESOURCE_ID_FIELD)
        log.info("Created shared collection: %s", mt_collection_name)
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        """Create the shared collection if Milvus does not have it yet."""
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    # ------------------------------------------------------------------
    # Public VectorDBBase API
    # ------------------------------------------------------------------

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one row is stored for ``collection_name``.

        A logical collection "exists" only when its shared collection exists
        and contains a row tagged with the matching resource id.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return False

        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=self._resource_clause(resource_id), limit=1)
        return len(res) > 0

    def _write_items(
        self, collection_name: str, items: List[VectorItem], replace_existing: bool
    ):
        """Store ``items`` under ``collection_name`` and flush.

        The shared collection is created on demand, sized from the first
        item's vector. With ``replace_existing`` the rows are written through
        ``Collection.upsert``; otherwise through ``Collection.insert``.
        An empty ``items`` list is a no-op.
        """
        if not items:
            return
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        dimension = len(items[0]["vector"])
        self._ensure_collection(mt_collection, dimension)
        collection = Collection(mt_collection)

        entities = [
            {
                "id": item["id"],
                "vector": item["vector"],
                "text": item["text"],
                "metadata": item["metadata"],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        if replace_existing:
            collection.upsert(entities)
        else:
            collection.insert(entities)
        collection.flush()

    def upsert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items``, replacing rows whose ``id`` already exists.

        Uses Milvus' upsert so that re-sending an id does not leave a second
        row with the same primary key behind.
        """
        return self._write_items(collection_name, items, replace_existing=True)

    def search(
        self, collection_name: str, vectors: List[List[float]], limit: int
    ) -> Optional[SearchResult]:
        """Run a vector search restricted to ``collection_name``.

        Returns ``None`` when ``vectors`` is empty or the shared collection
        does not exist. Otherwise returns a ``SearchResult`` whose ``ids``,
        ``documents``, ``metadatas`` and ``distances`` each hold one inner
        list per query vector, in hit order. Distances are the raw values
        reported by Milvus.
        """
        if not vectors:
            return None

        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
        results = collection.search(
            data=vectors,
            anns_field="vector",
            param=search_params,
            limit=limit,
            expr=self._resource_clause(resource_id),
            output_fields=["id", "text", "metadata"],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            ids.append([hit.entity.get("id") for hit in hits])
            documents.append([hit.entity.get("text") for hit in hits])
            metadatas.append([hit.entity.get("metadata") for hit in hits])
            distances.append([hit.distance for hit in hits])

        return SearchResult(
            ids=ids, documents=documents, metadatas=metadatas, distances=distances
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        """Delete rows of ``collection_name`` matching ``ids`` and/or ``filter``.

        All given conditions are AND-ed with the resource-id clause. When
        neither ``ids`` nor ``filter`` is supplied (or both are empty), every
        row of the logical collection is deleted. Does nothing if the shared
        collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)

        # Build expression
        expr = [self._resource_clause(resource_id)]
        if ids:
            # Milvus expects a string list for 'in' operator
            id_list_str = ", ".join(self._quote(id_val) for id_val in ids)
            expr.append(f"id in [{id_list_str}]")

        # Same clause builder as query(), so non-string values are compared
        # as-is instead of being turned into string literals.
        expr.extend(self._metadata_clauses(filter))

        collection.delete(" and ".join(expr))

    def reset(self):
        """Drop every shared collection that currently exists."""
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        """Delete all rows belonging to ``collection_name``.

        The shared collection itself is kept; only rows tagged with this
        resource id are removed.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._resource_clause(resource_id))

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch rows of ``collection_name`` whose metadata matches ``filter``.

        Returns ``None`` if the shared collection does not exist. Otherwise
        returns a ``GetResult`` whose ``ids``, ``documents`` and ``metadatas``
        each contain a single inner list (possibly empty) with the matches.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        expr = [self._resource_clause(resource_id)]
        expr.extend(self._metadata_clauses(filter))

        query_kwargs = {
            "expr": " and ".join(expr),
            "output_fields": ["id", "text", "metadata"],
        }
        # Leave the keyword out entirely when no limit was requested rather
        # than forwarding an explicit None to the driver.
        if limit is not None:
            query_kwargs["limit"] = limit
        results = collection.query(**query_kwargs)

        ids = [res["id"] for res in results]
        documents = [res["text"] for res in results]
        metadatas = [res["metadata"] for res in results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every row stored for ``collection_name`` (no filter, no limit)."""
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items`` into ``collection_name`` without replacing existing ids."""
        return self._write_items(collection_name, items, replace_existing=False)
|