open-webui/backend/open_webui/retrieval/utils.py

981 lines
33 KiB
Python
Raw Normal View History

import logging
2024-08-28 06:10:27 +08:00
import os
from typing import Optional, Union
2024-03-09 11:26:39 +08:00
2024-08-28 06:10:27 +08:00
import requests
2025-02-27 15:51:39 +08:00
import hashlib
2025-03-31 22:43:37 +08:00
from concurrent.futures import ThreadPoolExecutor
2025-05-20 10:58:04 +08:00
import time
2024-09-10 09:27:50 +08:00
from urllib.parse import quote
2024-04-25 20:49:59 +08:00
from huggingface_hub import snapshot_download
2024-08-28 06:10:27 +08:00
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
2024-08-28 06:10:27 +08:00
from langchain_core.documents import Document
2024-09-10 09:27:50 +08:00
2025-01-08 16:21:50 +08:00
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
2025-02-21 03:02:45 +08:00
from open_webui.models.users import UserModel
2025-02-27 07:42:19 +08:00
from open_webui.models.files import Files
2025-07-11 16:00:21 +08:00
from open_webui.models.knowledge import Knowledges
2025-07-09 05:17:25 +08:00
from open_webui.models.notes import Notes
2024-04-15 07:48:15 +08:00
2025-03-31 11:48:22 +08:00
from open_webui.retrieval.vector.main import GetResult
2025-07-11 16:00:21 +08:00
from open_webui.utils.access_control import has_access
2025-03-31 11:48:22 +08:00
2025-02-06 07:15:24 +08:00
2025-02-05 16:07:45 +08:00
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
2025-02-05 05:04:36 +08:00
from open_webui.config import (
2025-03-31 12:55:15 +08:00
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
2025-02-05 05:04:36 +08:00
)
2024-09-10 09:27:50 +08:00
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
2024-03-09 11:26:39 +08:00
2024-09-10 11:37:06 +08:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
    """Retriever that answers a query with a similarity search on VECTOR_DB_CLIENT."""

    # Collection name forwarded to VECTOR_DB_CLIENT.search.
    collection_name: Any
    # Callable invoked as embedding_function(query, prefix) to embed the query.
    embedding_function: Any
    # Maximum number of hits requested from the vector store.
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Embed ``query``, search the collection and wrap every hit in a Document."""
        query_vector = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
        search_result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[query_vector],
            limit=self.top_k,
        )

        # A single query vector is sent, so only the first result row is read.
        hit_ids = search_result.ids[0]
        hit_metadatas = search_result.metadatas[0]
        hit_documents = search_result.documents[0]

        return [
            Document(
                metadata=hit_metadatas[position],
                page_content=hit_documents[position],
            )
            for position in range(len(hit_ids))
        ]
2024-09-10 11:37:06 +08:00
2024-04-28 03:38:50 +08:00
def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    """Search one collection with one query embedding.

    Args:
        collection_name: Collection to search.
        query_embedding: Query vector passed to the vector store.
        k: Maximum number of hits requested.
        user: Accepted for signature compatibility; not used in this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.search`` returns.

    Raises:
        Exception: Any error from the search is logged and re-raised.
    """
    try:
        log.debug(f"query_doc:doc {collection_name}")
        search_hits = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if search_hits:
            log.info(f"query_doc:result {search_hits.ids} {search_hits.metadatas}")

        return search_hits
    except Exception as err:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {err}")
        raise err
2024-04-26 05:03:00 +08:00
2025-02-19 13:14:58 +08:00
def get_doc(collection_name: str, user: UserModel = None):
    """Fetch the stored content of a collection from the vector store.

    Args:
        collection_name: Collection to read.
        user: Accepted for signature compatibility; not used in this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.get`` returns.

    Raises:
        Exception: Any error from the vector store is logged and re-raised.
    """
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            # Labelled "get_doc" so the entry is not mistaken for query_doc output.
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        # Bare raise keeps the original traceback intact.
        raise
2024-04-28 03:38:50 +08:00
def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    """Query one collection with BM25 + vector search, then rerank the hits.

    Args:
        collection_name: Collection searched by the vector retriever.
        collection_result: Pre-fetched collection content; its first
            ``documents`` / ``metadatas`` row feeds the BM25 retriever.
        query: Query text.
        embedding_function: Callable used for the vector search and, when no
            reranking function is given, for cosine-similarity scoring.
        k: Number of hits each retriever returns and upper bound of the result.
        reranking_function: Optional callable that scores (query, text) pairs.
        k_reranker: Number of documents the reranker keeps.
        r: Minimum score; passed to the compressor as ``r_score``.
        hybrid_bm25_weight: BM25 weight; ``<= 0`` uses vector search only,
            ``>= 1`` uses BM25 only, anything else mixes both.

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        with the three inner lists aligned by position.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif hybrid_bm25_weight >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # Retrieve only min(k, k_reranker) items: sort by score and cut when
        # k < k_reranker.
        if k < k_reranker:
            ranked = sorted(
                zip(distances, documents, metadatas),
                key=lambda item: item[0],
                reverse=True,
            )[:k]
            # Rebuild each list by position so documents and metadatas cannot
            # be swapped, and so an empty result does not break unpacking.
            distances = [item[0] for item in ranked]
            documents = [item[1] for item in ranked]
            metadatas = [item[2] for item in ranked]

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise
2025-02-19 13:14:58 +08:00
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate the first row of several "get" results into one result.

    Each input dict must provide ``documents``, ``metadatas`` and ``ids``, each
    a list whose first element is the row to merge. The output has the same
    three keys, each holding a single merged row.
    """
    merged_rows = {"documents": [], "metadatas": [], "ids": []}

    for entry in get_results:
        for field in ("documents", "metadatas", "ids"):
            merged_rows[field].extend(entry[field][0])

    # Re-wrap every merged row in an outer list to keep the one-row shape.
    return {field: [row] for field, row in merged_rows.items()}
2025-03-26 02:09:17 +08:00
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    """Merge query results, de-duplicate by document text and keep the top ``k``.

    Args:
        query_results: Dicts with ``distances``, ``documents`` and
            ``metadatas``; only the first row of each is read.
        k: Maximum number of entries to keep after sorting.

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        sorted by distance in descending order (larger values rank first).
        The inner lists are empty when nothing is selected, including ``k == 0``.

    Documents that are not ``str`` are skipped. When the same text appears
    more than once, the entry with the largest distance wins.
    """
    # Best (distance, document, metadata) seen so far, keyed by content hash.
    best_by_hash = {}

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if not isinstance(document, str):
                continue

            doc_hash = hashlib.sha256(document.encode()).hexdigest()
            current = best_by_hash.get(doc_hash)

            # New document, or already present but the new distance is better.
            if current is None or distance > current[0]:
                best_by_hash[doc_hash] = (distance, document, metadata)

    ranked = sorted(best_by_hash.values(), key=lambda item: item[0], reverse=True)
    top = ranked[:k]

    # Build each column explicitly: unpacking zip(*top) fails when top is empty.
    return {
        "distances": [[item[0] for item in top]],
        "documents": [[item[1] for item in top]],
        "metadatas": [[item[2] for item in top]],
    }
2024-03-09 11:26:39 +08:00
2025-02-19 13:14:58 +08:00
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Fetch every named collection and merge the contents into one result.

    Falsy names are skipped. A failure for one collection is logged and does
    not prevent the remaining collections from being read.
    """
    fetched = []

    for name in collection_names:
        if not name:
            continue

        try:
            collection = get_doc(collection_name=name)
            if collection is not None:
                fetched.append(collection.model_dump())
        except Exception as err:
            log.exception(f"Error when querying the collection: {err}")

    return merge_get_results(fetched)
2024-04-28 03:38:50 +08:00
def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    """Run a vector search for every (query, collection) pair and merge the hits.

    All queries are embedded with a single ``embedding_function`` call; the
    searches are then submitted to a thread pool. Failed searches are logged
    and left out of the merged result.
    """

    def process_query_collection(collection_name, query_embedding):
        # Returns (result_dict_or_None, exception_or_None).
        try:
            if not collection_name:
                return None, None

            hits = query_doc(
                collection_name=collection_name,
                k=k,
                query_embedding=query_embedding,
            )
            if hits is None:
                return None, None
            return hits.model_dump(), None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        pending = [
            executor.submit(process_query_collection, name, embedding)
            for embedding in query_embeddings
            for name in collection_names
        ]
        outcomes = [future.result() for future in pending]

    any_failed = any(failure is not None for _, failure in outcomes)
    successful = [
        payload
        for payload, failure in outcomes
        if failure is None and payload is not None
    ]

    if any_failed and not successful:
        log.warning("All collection queries failed. No results returned.")

    return merge_and_sort_query_results(successful, k=k)
2024-04-28 03:38:50 +08:00
def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    """Run hybrid search for every (collection, query) pair and merge the hits.

    Each collection is fetched once up front; collections whose fetch fails
    are excluded. The searches run on a thread pool.

    Raises:
        Exception: When at least one search failed and none produced a result.
    """

    def fetch_collection(name):
        # Returns the collection content, or None when the fetch failed.
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {name}"
            )
            return VECTOR_DB_CLIENT.get(collection_name=name)
        except Exception as e:
            log.exception(f"Failed to fetch collection {name}: {e}")
            return None

    # Fetch collection data once per collection sequentially
    # Avoid fetching the same data multiple times later
    collection_results = {name: fetch_collection(name) for name in collection_names}

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    def process_query(collection_name, query):
        # Returns (result_dict_or_None, exception_or_None).
        try:
            hits = query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
            )
            return hits, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Prepare tasks for all collections and queries
    # Avoid running any tasks for collections that failed to fetch data (have assigned None)
    tasks = []
    for name in collection_names:
        if collection_results[name] is None:
            continue
        for query in queries:
            tasks.append((name, query))

    with ThreadPoolExecutor() as executor:
        pending = [executor.submit(process_query, name, query) for name, query in tasks]
        outcomes = [future.result() for future in pending]

    any_failed = any(failure is not None for _, failure in outcomes)
    successful = [
        payload
        for payload, failure in outcomes
        if failure is None and payload is not None
    ]

    if any_failed and not successful:
        raise Exception(
            "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
        )

    return merge_and_sort_query_results(successful, k=k)
2024-04-15 05:55:00 +08:00
2025-03-27 16:40:28 +08:00
2024-04-28 03:38:50 +08:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
):
    """Build a callable ``f(query, prefix=None, user=None)`` that embeds text.

    With ``embedding_engine == ""`` the callable uses
    ``embedding_function.encode`` and passes a truthy ``prefix`` as ``prompt``.
    With ``"ollama"``, ``"openai"`` or ``"azure_openai"`` it delegates to
    ``generate_embeddings`` and splits list input into slices of
    ``embedding_batch_size``.

    Raises:
        ValueError: For any other engine name.
    """
    if embedding_engine == "":

        def embed_locally(query, prefix=None, user=None):
            encode_kwargs = {"prompt": prefix} if prefix else {}
            return embedding_function.encode(query, **encode_kwargs).tolist()

        return embed_locally

    if embedding_engine not in ["ollama", "openai", "azure_openai"]:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")

    def embed_remotely(query, prefix=None, user=None):
        return generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
            azure_api_version=azure_api_version,
        )

    def embed_in_batches(query, prefix=None, user=None):
        # A single string goes through as-is; only lists are batched.
        if not isinstance(query, list):
            return embed_remotely(query, prefix, user)

        collected = []
        for start in range(0, len(query), embedding_batch_size):
            chunk = query[start : start + embedding_batch_size]
            collected.extend(embed_remotely(chunk, prefix=prefix, user=user))
        return collected

    return embed_in_batches
2024-04-23 04:49:58 +08:00
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap ``reranking_function.predict`` in a ``f(sentences, user=None)`` callable.

    Returns ``None`` when no reranking function is configured. ``user`` is
    forwarded to ``predict`` only when ``reranking_engine`` is ``"external"``.
    ``reranking_model`` is not used here.
    """
    if reranking_function is None:
        return None

    forwards_user = reranking_engine == "external"

    def rerank(sentences, user=None):
        if forwards_user:
            return reranking_function.predict(sentences, user=user)
        return reranking_function.predict(sentences)

    return rerank
2025-07-11 16:00:21 +08:00
def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    """Resolve chat items (text, notes, files, collections) into source dicts.

    For every item, either the content is taken directly (raw text, notes,
    "full" context files/collections, pre-fetched ``docs``) or the item is
    mapped to collection names that are then searched.

    Args:
        request: Request object; only
            ``request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL`` is read.
        items: Item dicts; the ``type`` key selects how each is handled.
        queries: Query strings used for retrieval.
        embedding_function: Callable used to embed queries.
        k: Number of results to retrieve.
        reranking_function: Optional reranker used by hybrid search.
        k_reranker: Number of documents the reranker keeps.
        r: Minimum relevance score for hybrid search.
        hybrid_bm25_weight: BM25 weight for hybrid search.
        hybrid_search: Whether to try hybrid search first.
        full_context: When true, return whole collections instead of searching.
        user: Requesting user. When ``None``, access-controlled notes and
            full-context knowledge bases are treated as not accessible.

    Returns:
        A list of dicts with ``source``, ``document``, ``metadata`` and, when
        available, ``distances``.

    Note:
        A ``data`` key is deleted from an item once it produced a result, so
        the caller's item dicts are modified.
    """
    log.debug(
        f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    query_results = []

    for item in items:
        query_result = None
        collection_names = []

        if item.get("type") == "text":
            # Raw Text
            # Used during temporary chat file uploads
            if item.get("file"):
                # if item has file data, use it
                query_result = {
                    "documents": [
                        [item.get("file", {}).get("data", {}).get("content")]
                    ],
                    "metadatas": [
                        [item.get("file", {}).get("data", {}).get("meta", {})]
                    ],
                }
            else:
                # Fallback to item content
                query_result = {
                    "documents": [[item.get("content")]],
                    "metadatas": [
                        [{"file_id": item.get("id"), "name": item.get("name")}]
                    ],
                }

        elif item.get("type") == "note":
            # Note Attached
            note = Notes.get_note_by_id(item.get("id"))

            # Without a user there is nobody to authorize, so deny access
            # instead of failing on user.role / user.id.
            if (
                note
                and user is not None
                and (
                    user.role == "admin"
                    or note.user_id == user.id
                    or has_access(user.id, "read", note.access_control)
                )
            ):
                # User has access to the note
                query_result = {
                    "documents": [[note.data.get("content", {}).get("md", "")]],
                    "metadatas": [[{"file_id": note.id, "name": note.title}]],
                }

        elif item.get("type") == "file":
            if (
                item.get("context") == "full"
                or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
            ):
                if item.get("file", {}).get("data", {}).get("content", ""):
                    # Manual Full Mode Toggle
                    # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content", "")]
                        ],
                        "metadatas": [
                            [
                                {
                                    "file_id": item.get("id"),
                                    "name": item.get("name"),
                                    **item.get("file")
                                    .get("data", {})
                                    .get("metadata", {}),
                                }
                            ]
                        ],
                    }
                elif item.get("id"):
                    file_object = Files.get_file_by_id(item.get("id"))
                    if file_object:
                        query_result = {
                            "documents": [[file_object.data.get("content", "")]],
                            "metadatas": [
                                [
                                    {
                                        "file_id": item.get("id"),
                                        "name": file_object.filename,
                                        "source": file_object.filename,
                                    }
                                ]
                            ],
                        }
            else:
                # Fallback to collection names
                if item.get("legacy"):
                    collection_names.append(f"{item['id']}")
                else:
                    collection_names.append(f"file-{item['id']}")

        elif item.get("type") == "collection":
            if (
                item.get("context") == "full"
                or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
            ):
                # Manual Full Mode Toggle for Collection
                knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))

                # Same rule as for notes: no user means no access.
                if (
                    knowledge_base
                    and user is not None
                    and (
                        user.role == "admin"
                        or has_access(user.id, "read", knowledge_base.access_control)
                    )
                ):
                    file_ids = knowledge_base.data.get("file_ids", [])

                    documents = []
                    metadatas = []
                    for file_id in file_ids:
                        file_object = Files.get_file_by_id(file_id)

                        if file_object:
                            documents.append(file_object.data.get("content", ""))
                            metadatas.append(
                                {
                                    "file_id": file_id,
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            )

                    query_result = {
                        "documents": [documents],
                        "metadatas": [metadatas],
                    }
            else:
                # Fallback to collection names
                if item.get("legacy"):
                    collection_names = item.get("collection_names", [])
                else:
                    collection_names.append(item["id"])

        elif item.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            query_result = {
                "documents": [[doc.get("content") for doc in item.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
            }
        elif item.get("collection_name"):
            # Direct Collection Name
            collection_names.append(item["collection_name"])
        elif item.get("collection_names"):
            # Collection Names List
            collection_names.extend(item["collection_names"])

        # If query_result is None
        # Fallback to collection names and vector search the collections
        if query_result is None and collection_names:
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {item} as it has already been extracted")
                continue

            try:
                if full_context:
                    query_result = get_all_items_from_collections(collection_names)
                else:
                    query_result = None  # Initialize to None
                    if hybrid_search:
                        try:
                            query_result = query_collection_with_hybrid_search(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                                reranking_function=reranking_function,
                                k_reranker=k_reranker,
                                r=r,
                                hybrid_bm25_weight=hybrid_bm25_weight,
                            )
                        except Exception as e:
                            log.debug(
                                "Error when using hybrid search, using non hybrid search as fallback."
                            )

                    # Plain vector search runs when hybrid search is disabled
                    # AND when it was attempted but produced no result, so the
                    # fallback announced above really happens.
                    if query_result is None:
                        query_result = query_collection(
                            collection_names=collection_names,
                            queries=queries,
                            embedding_function=embedding_function,
                            k=k,
                        )
            except Exception as e:
                log.exception(e)

            extracted_collections.extend(collection_names)

        if query_result:
            if "data" in item:
                del item["data"]
            query_results.append({**query_result, "file": item})

    sources = []
    for query_result in query_results:
        try:
            if "documents" in query_result and "metadatas" in query_result:
                source = {
                    "source": query_result["file"],
                    "document": query_result["documents"][0],
                    "metadata": query_result["metadatas"][0],
                }
                if "distances" in query_result and query_result["distances"]:
                    source["distances"] = query_result["distances"][0]

                sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
2024-04-05 02:07:42 +08:00
2024-04-25 20:49:59 +08:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve a model name to a local path, via huggingface_hub when needed.

    Args:
        model: A filesystem path, a ``org/name`` repo id, or a short name
            (which is prefixed with ``sentence-transformers/``).
        update_model: When false, or when OFFLINE_MODE is set, only local
            files are considered.

    Returns:
        The snapshot path reported by ``snapshot_download``; ``model`` itself
        when it is treated as a path or when the lookup fails.
    """
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
    local_files_only = True if OFFLINE_MODE else not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        # If fully qualified path exists, return input, else set repo_id
        return model

    if "/" not in model:
        # Set valid repo_id for model short-name
        model = f"sentence-transformers/{model}"

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
    except Exception as err:
        log.exception(f"Cannot determine model snapshot path: {err}")
        return model

    log.debug(f"model_repo_path: {model_repo_path}")
    return model_repo_path
2024-04-25 20:49:59 +08:00
2024-11-19 06:19:56 +08:00
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Request embeddings for ``texts`` from ``{url}/embeddings``.

    Args:
        model: Model name sent in the request body.
        texts: Texts to embed in one request.
        url: Base URL of the API.
        key: Bearer token.
        prefix: Sent under RAG_EMBEDDING_PREFIX_FIELD_NAME when both that
            setting and ``prefix`` are strings.
        user: When given and ENABLE_FORWARD_USER_INFO_HEADERS is set, the
            user's details are forwarded as ``X-OpenWebUI-User-*`` headers.

    Returns:
        One embedding per text, or ``None`` when the request fails; the
        failure is logged.
    """
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            # Raise a real exception: raising a str is itself a TypeError and
            # would hide this message in the log.
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
2024-04-23 04:49:58 +08:00
2025-05-20 10:58:04 +08:00
def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Request embeddings for ``texts`` from an Azure OpenAI deployment.

    The request is attempted up to five times; an HTTP 429 response waits for
    the ``Retry-After`` header value (default 1 second) and retries.

    Returns:
        One embedding per text; ``None`` when every attempt was answered with
        429 or when any error occurred (errors are logged).
    """
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        payload = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        endpoint = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        # The headers are identical for every attempt, so build them once.
        request_headers = {
            "Content-Type": "application/json",
            "api-key": key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers.update(
                {
                    "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
            )

        attempts_left = 5
        while attempts_left > 0:
            attempts_left -= 1
            response = requests.post(
                endpoint,
                headers=request_headers,
                json=payload,
            )
            if response.status_code == 429:
                wait_seconds = float(response.headers.get("Retry-After", "1"))
                time.sleep(wait_seconds)
                continue

            response.raise_for_status()
            body = response.json()
            if "data" not in body:
                raise Exception("Something went wrong :/")
            return [elem["embedding"] for elem in body["data"]]

        return None
    except Exception as err:
        log.exception(f"Error generating azure openai batch embeddings: {err}")
        return None
2024-11-19 06:19:56 +08:00
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Request embeddings for ``texts`` from ``{url}/api/embed``.

    Args:
        model: Model name sent in the request body.
        texts: Texts to embed in one request.
        url: Base URL of the server.
        key: Bearer token.
        prefix: Sent under RAG_EMBEDDING_PREFIX_FIELD_NAME when both that
            setting and ``prefix`` are strings.
        user: When given and ENABLE_FORWARD_USER_INFO_HEADERS is set, the
            user's details are forwarded as ``X-OpenWebUI-User-*`` headers.

    Returns:
        The ``embeddings`` value of the response, or ``None`` when the
        request fails; the failure is logged.
    """
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    # `and user` matches the OpenAI/Azure variants: calls made
                    # without a user must not fail on user.name.
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        else:
            # Raise a real exception: raising a str is itself a TypeError and
            # would hide this message in the log.
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
2025-03-31 12:55:15 +08:00
def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    """Embed ``text`` with the ``ollama``, ``openai`` or ``azure_openai`` engine.

    ``url``, ``key``, ``user`` and (for Azure) ``azure_api_version`` are read
    from ``kwargs``. When a prefix is given and RAG_EMBEDDING_PREFIX_FIELD_NAME
    is ``None``, the prefix is prepended to the text itself.

    Returns:
        A single embedding for ``str`` input, otherwise the batch function's
        result; ``None`` for an engine name that is not handled.
    """
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{piece}" for piece in text]
        else:
            text = f"{prefix}{text}"

    # The batch functions always take a list, even for a single text.
    batch = text if isinstance(text, list) else [text]

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            model=model,
            texts=batch,
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, batch, url, key, prefix, user
        )
    elif engine == "azure_openai":
        embeddings = generate_azure_openai_batch_embeddings(
            model,
            batch,
            url,
            key,
            kwargs.get("azure_api_version", ""),
            prefix,
            user,
        )
    else:
        return None

    return embeddings[0] if isinstance(text, str) else embeddings
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-28 06:10:27 +08:00
from langchain_core.documents import BaseDocumentCompressor, Document
class RerankCompressor(BaseDocumentCompressor):
    """Document compressor that scores documents against a query and keeps the best.

    Scores come from ``reranking_function`` when it is set, otherwise from the
    cosine similarity between the query embedding and the document embeddings.
    """

    # Callable invoked as embedding_function(text_or_texts, prefix).
    embedding_function: Any
    # Maximum number of documents returned.
    top_n: int
    # Optional callable scoring a list of (query, text) pairs; None disables it.
    reranking_function: Any
    # Minimum score to keep a document; a falsy value disables the filter.
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Score ``documents`` against ``query`` and return the top ``top_n``.

        Returns new Document objects, best score first, each carrying its
        score under ``metadata["score"]``. The input documents and their
        metadata dicts are left untouched.
        """
        if not documents:
            # Nothing to score; skip the reranker/embedding calls entirely.
            return []

        if self.reranking_function is not None:
            scores = self.reranking_function(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(
            zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
        )
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        final_results = []
        for doc, doc_score in ranked[: self.top_n]:
            # Copy the metadata: the same dict may still be referenced by
            # whatever produced the input documents, and writing the score
            # into it would leak this query's score to other users of it.
            metadata = {**doc.metadata, "score": doc_score}
            final_results.append(
                Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                )
            )
        return final_results