import json
import logging
import os
from typing import Optional

from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MilvusClient(VectorDBBase):
    """Milvus-backed implementation of the VectorDBBase interface.

    Collections are namespaced with the `open_webui` prefix, and hyphens in
    collection names are replaced with underscores to satisfy Milvus naming
    rules.
    """

    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

    def _result_to_get_result(self, result) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))
            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # Normalize the Milvus score from the [-1, 1] range to [0, 1].
                # https://milvus.io/docs/de/metric.md
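                # e.g. a COSINE similarity of 0.5 from Milvus maps to
                # (0.5 + 1.0) / 2.0 = 0.75 after normalization.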
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))
            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Select the index type from the MILVUS_INDEX_TYPE environment variable.
        # Milvus standalone (local mode) supports FLAT, IVF_FLAT, and AUTOINDEX;
        # HNSW is often preferred for performance but may require a clustered
        # Milvus setup. Default to AUTOINDEX for the broadest compatibility,
        # especially with Milvus standalone.
        default_index_type = "AUTOINDEX"
        milvus_index_type_env = os.getenv("MILVUS_INDEX_TYPE")
        if milvus_index_type_env:
            milvus_index_type = milvus_index_type_env.upper()
            log.info(
                f"Milvus index type from MILVUS_INDEX_TYPE env var: {milvus_index_type}"
            )
        else:
            milvus_index_type = default_index_type
            log.info(
                f"MILVUS_INDEX_TYPE env var not set, defaulting to: {milvus_index_type}"
            )

        index_creation_params = {}
        metric_type = os.getenv("MILVUS_METRIC_TYPE", "COSINE").upper()  # Default to COSINE.

        if milvus_index_type == "HNSW":
            # Build-time parameters for HNSW.
            m_env = os.getenv("MILVUS_HNSW_M", "16")
            ef_construction_env = os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100")
            try:
                m_val = int(m_env)
                ef_val = int(ef_construction_env)
            except ValueError:
                log.warning(
                    f"Invalid HNSW params M='{m_env}' or efConstruction='{ef_construction_env}'. "
                    f"Defaulting to M=16, efConstruction=100."
                )
                m_val = 16
                ef_val = 100
            index_creation_params = {"M": m_val, "efConstruction": ef_val}
            log.info(
                f"Using HNSW index with metric {metric_type}, params: {index_creation_params}"
            )
        elif milvus_index_type == "IVF_FLAT":
            # Build-time parameters for IVF_FLAT.
            nlist_env = os.getenv("MILVUS_IVF_FLAT_NLIST", "128")
            try:
                nlist = int(nlist_env)
            except ValueError:
                log.warning(
                    f"Invalid MILVUS_IVF_FLAT_NLIST value '{nlist_env}'. Defaulting to 128."
                )
                nlist = 128
            index_creation_params = {"nlist": nlist}
            log.info(
                f"Using IVF_FLAT index with metric {metric_type}, params: {index_creation_params}"
            )
        elif milvus_index_type == "FLAT":
            # FLAT needs no build-time parameters.
            log.info(f"Using FLAT index with metric {metric_type} (no build-time params).")
        elif milvus_index_type == "AUTOINDEX":
            # AUTOINDEX parameters are managed by Milvus itself.
            log.info(f"Using AUTOINDEX with metric {metric_type} (params managed by Milvus).")
        else:
            log.warning(
                f"Unsupported or unrecognized MILVUS_INDEX_TYPE: '{milvus_index_type}'. "
                f"Falling back to '{default_index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX."
            )
            milvus_index_type = default_index_type  # Fall back to a safe default.
            # index_creation_params stays {}, which is fine for AUTOINDEX.
            log.info(f"Fell back to {default_index_type} index with metric {metric_type}.")

        index_params.add_index(
            field_name="vector",
            index_type=milvus_index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' "
            f"with index type '{milvus_index_type}' and metric '{metric_type}'."
        )
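
    # A minimal configuration sketch for the index selection above (the env
    # var names are real, the values are assumed and shown for illustration
    # only):
    #
    #     MILVUS_INDEX_TYPE=HNSW
    #     MILVUS_METRIC_TYPE=COSINE
    #     MILVUS_HNSW_M=16
    #     MILVUS_HNSW_EFCONSTRUCTION=100
    #
    # With these set, _create_collection builds an HNSW index with the given
    # build-time parameters; with nothing set, it falls back to AUTOINDEX.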

    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbors of the given vectors and return up
        # to 'limit' results.
        collection_name = collection_name.replace("-", "_")

        # For some index types, such as IVF_FLAT, search params like nprobe
        # can be set. For simplicity, configurable search_params are not added
        # here, but the call below could be extended to accept them.
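        # A minimal sketch of how that extension could look (the MILVUS_NPROBE
        # env var below is hypothetical, not part of the current config):
        #
        #     search_params = {"params": {"nprobe": int(os.getenv("MILVUS_NPROBE", "10"))}}
        #
        # and then pass search_params=search_params to self.client.search below.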
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Construct the filter string for querying.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: "
                f"{self.collection_prefix}_{collection_name}"
            )
            return None

        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )
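
        # For example, a filter of {"file_id": "abc"} (hypothetical key, for
        # illustration) becomes the expression 'metadata["file_id"] == "abc"'.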

        max_limit = 16383  # The maximum number of records per request.
        all_results = []

        if limit is None:
            # A limit of None means "fetch as many as possible". Fetch in
            # batches of up to max_limit under a large overall cap; the loop
            # below stops early once the collection is exhausted, so the cap
            # is rarely reached.
            limit = 16384 * 10
            log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")

        # Initialize offset and remaining to handle pagination.
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} "
                f"with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit
            # is reached.
            while remaining > 0:
                current_fetch = min(max_limit, remaining)
                log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    # Explicitly list the needed fields; the vector itself is
                    # usually not needed in query results.
                    output_fields=["id", "data", "metadata"],
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                remaining -= results_count
                offset += results_count

                # Stop if fewer results were returned than requested; the end
                # of the data has been reached.
                if results_count < current_fetch:
                    log.debug("Fetched less than requested, assuming end of results for this query.")
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} "
                f"with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all items in the collection. This can be very resource-intensive
        # for large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. "
            f"This might be slow for large collections."
        )
        # Use query with a trivial filter to get all items; this goes through
        # the paginated query logic above.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection; if the collection does not
        # exist, it is created first.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} "
                    f"without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )
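
    # A usage sketch for insert/upsert (hypothetical values, assuming the
    # dict-shaped VectorItem implied by the subscripting above):
    #
    #     client = MilvusClient()
    #     client.insert(
    #         "my-docs",
    #         [{"id": "1", "vector": [0.1, 0.2], "text": "hello", "metadata": {"file_id": "abc"}}],
    #     )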

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection; items that are not present are
        # inserted, and if the collection does not exist, it is created first.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} "
                    f"for upsert without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete items from the collection by IDs or by metadata filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: "
                f"{self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. "
                f"Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} "
                f"called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        # Reset the database: delete all collections whose names match the prefix.
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")