open-webui/backend/open_webui/retrieval/utils.py

856 lines
28 KiB
Python
Raw Normal View History

import logging
2024-08-28 06:10:27 +08:00
import os
from typing import Optional, Union
2024-03-09 11:26:39 +08:00
2024-08-28 06:10:27 +08:00
import requests
2025-02-27 15:51:39 +08:00
import hashlib
2025-03-31 22:43:37 +08:00
from concurrent.futures import ThreadPoolExecutor
2024-09-10 09:27:50 +08:00
2024-04-25 20:49:59 +08:00
from huggingface_hub import snapshot_download
2024-08-28 06:10:27 +08:00
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
2024-08-28 06:10:27 +08:00
from langchain_core.documents import Document
2024-09-10 09:27:50 +08:00
2025-01-08 16:21:50 +08:00
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
2025-02-21 03:02:45 +08:00
from open_webui.models.users import UserModel
2025-02-27 07:42:19 +08:00
from open_webui.models.files import Files
2024-04-15 07:48:15 +08:00
2025-03-31 11:48:22 +08:00
from open_webui.retrieval.vector.main import GetResult
2025-02-06 07:15:24 +08:00
2025-02-05 16:07:45 +08:00
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
2025-02-05 05:04:36 +08:00
from open_webui.config import (
2025-03-31 12:55:15 +08:00
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
2025-02-05 05:04:36 +08:00
)
2024-09-10 09:27:50 +08:00
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
2024-03-09 11:26:39 +08:00
2024-09-10 11:37:06 +08:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
    """LangChain retriever that performs a dense vector search on one collection.

    The query text is embedded with ``embedding_function`` (using the
    configured query prefix) and the ``top_k`` nearest entries of
    ``collection_name`` are returned as ``Document`` objects.
    """

    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        query_vector = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
        search_result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[query_vector],
            limit=self.top_k,
        )

        # The search API is batched (one entry per query vector); exactly one
        # vector was sent, so only the first entry of each list is relevant.
        hit_ids = search_result.ids[0]
        hit_metadatas = search_result.metadatas[0]
        hit_documents = search_result.documents[0]

        return [
            Document(metadata=hit_metadatas[i], page_content=hit_documents[i])
            for i in range(len(hit_ids))
        ]
2024-09-10 11:37:06 +08:00
2024-04-28 03:38:50 +08:00
def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    """Run a nearest-neighbour search for a single embedding in one collection.

    Args:
        collection_name: Vector-DB collection to search.
        query_embedding: Pre-computed embedding of the query text.
        k: Maximum number of hits requested from the vector DB.
        user: Accepted for interface compatibility; not used by this function.

    Returns:
        The object returned by ``VECTOR_DB_CLIENT.search`` (may be falsy).

    Raises:
        Any exception raised by the vector-DB client, after it has been logged.
    """
    try:
        log.debug(f"query_doc:doc {collection_name}")
        hits = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )
        if hits:
            log.info(f"query_doc:result {hits.ids} {hits.metadatas}")
        return hits
    except Exception as err:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {err}")
        raise err
2024-04-26 05:03:00 +08:00
2025-02-19 13:14:58 +08:00
def get_doc(collection_name: str, user: UserModel = None):
    """Fetch every stored item of one vector-DB collection.

    Args:
        collection_name: Collection to read.
        user: Accepted for interface compatibility; not used by this function.

    Returns:
        The object returned by ``VECTOR_DB_CLIENT.get`` (may be falsy/None).

    Raises:
        Any exception raised by the vector-DB client, after it has been logged.
    """
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
        if result:
            # Log under this function's own name so traces are attributable.
            log.info(f"get_doc:result {result.ids} {result.metadatas}")
        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e
2024-04-28 03:38:50 +08:00
def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    """Run BM25 + vector search over one collection and rerank the candidates.

    Args:
        collection_name: Collection used for the dense (vector) search.
        collection_result: Pre-fetched contents of the collection; its first
            ``documents`` / ``metadatas`` lists feed the BM25 retriever.
        query: Query text.
        embedding_function: Callable ``(text, prefix) -> embedding`` used by the
            vector retriever and, when no reranker is set, by the compressor.
        k: Number of hits each retriever returns, and the final result cap.
        reranking_function: Optional reranker handed to ``RerankCompressor``.
        k_reranker: Number of documents the compressor keeps.
        r: Minimum relevance score handed to ``RerankCompressor``.
        hybrid_bm25_weight: BM25 weight in the ensemble. ``<= 0`` uses only
            vector search, ``>= 1`` only BM25, anything in between mixes both.

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        whose inner lists are aligned element-by-element.

    Raises:
        Any exception raised while searching, after it has been logged.
    """
    try:
        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif hybrid_bm25_weight >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # The compressor keeps up to k_reranker items; when k is smaller, keep
        # only the k best. The tuples are zipped and unpacked in the SAME
        # (distance, document, metadata) order so documents and metadatas stay
        # aligned, and comprehensions are used so an empty selection yields
        # empty lists instead of a failed 3-way unpack of zip(*[]).
        if k < k_reranker:
            top_items = sorted(
                zip(distances, documents, metadatas),
                key=lambda item: item[0],
                reverse=True,
            )[:k]
            distances = [item[0] for item in top_items]
            documents = [item[1] for item in top_items]
            metadatas = [item[2] for item in top_items]

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise e
2025-02-19 13:14:58 +08:00
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate several vector-DB ``get`` results into a single result.

    Each input dict stores its ``documents``, ``metadatas`` and ``ids`` lists
    wrapped in an outer single-element list (``[[...]]``). The inner lists are
    concatenated in input order and re-wrapped the same way; other keys are
    ignored.
    """
    merged = {"documents": [], "metadatas": [], "ids": []}
    for get_result in get_results:
        for field, collected in merged.items():
            collected.extend(get_result[field][0])
    return {field: [collected] for field, collected in merged.items()}
2025-03-26 02:09:17 +08:00
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    """Merge several query results, de-duplicate by content and keep the top ``k``.

    Each input dict holds ``distances``, ``documents`` and ``metadatas`` lists
    wrapped in an outer single-element list (``[[...]]``). Entries are
    de-duplicated by the SHA-256 of their document text. When the same text
    appears more than once, the entry with the strictly larger ``distance``
    wins, because the value is treated as a score where higher is better.
    Entries whose document is not a ``str`` are dropped.

    Args:
        query_results: Results to merge.
        k: Maximum number of entries to return (``0`` yields empty lists).

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        sorted by descending distance and truncated to ``k`` entries.
    """
    # doc_hash -> (distance, document, metadata). Dict order is first-seen
    # order, and the stable sort below uses it to break ties.
    best_by_hash: dict[str, tuple] = {}

    for data in query_results:
        for distance, document, metadata in zip(
            data["distances"][0], data["documents"][0], data["metadatas"][0]
        ):
            if not isinstance(document, str):
                continue
            doc_hash = hashlib.sha256(document.encode()).hexdigest()
            current = best_by_hash.get(doc_hash)
            # Keep the first occurrence unless a later one scores strictly higher.
            if current is None or distance > current[0]:
                best_by_hash[doc_hash] = (distance, document, metadata)

    ranked = sorted(best_by_hash.values(), key=lambda item: item[0], reverse=True)[:k]

    # Comprehensions (not zip(*ranked)) so an empty selection still produces
    # three empty lists rather than a failed unpack.
    return {
        "distances": [[item[0] for item in ranked]],
        "documents": [[item[1] for item in ranked]],
        "metadatas": [[item[2] for item in ranked]],
    }
2024-03-09 11:26:39 +08:00
2025-02-19 13:14:58 +08:00
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Fetch the full contents of several collections and merge them.

    Falsy collection names are skipped. A collection whose fetch (or dump)
    raises is logged and skipped so the remaining collections are still
    returned.
    """
    fetched = []
    for name in filter(None, collection_names):
        try:
            items = get_doc(collection_name=name)
            if items is not None:
                fetched.append(items.model_dump())
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
    return merge_get_results(fetched)
2024-04-28 03:38:50 +08:00
def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    """Dense-search every (query, collection) pair and merge the top ``k`` hits.

    All queries are embedded in a single ``embedding_function`` call. Each
    (embedding, collection) pair is then searched on a thread pool. Failed
    searches are logged and skipped, and a warning is logged when every search
    failed and nothing was returned.
    """

    def search_one(collection_name, query_embedding):
        # Returns (result_dict_or_None, exception_or_None); never raises.
        try:
            if not collection_name:
                return None, None
            hits = query_doc(
                collection_name=collection_name,
                k=k,
                query_embedding=query_embedding,
            )
            return (hits.model_dump() if hits is not None else None), None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Embed every query in one call.
    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(search_one, collection_name, query_embedding)
            for query_embedding in query_embeddings
            for collection_name in collection_names
        ]
        outcomes = [future.result() for future in futures]

    results = [
        payload for payload, err in outcomes if err is None and payload is not None
    ]
    any_failed = any(err is not None for _, err in outcomes)

    if any_failed and not results:
        log.warning("All collection queries failed. No results returned.")
    return merge_and_sort_query_results(results, k=k)
2024-04-28 03:38:50 +08:00
def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    """Run hybrid (BM25 + vector) search for every query over every collection.

    Each collection is fetched once, up front. Every (collection, query) pair
    is then searched on a thread pool with ``query_doc_with_hybrid_search``,
    and the results are merged, de-duplicated and cut to the top ``k`` with
    ``merge_and_sort_query_results``.

    Raises:
        Exception: if at least one collection fetch or search failed and no
            search produced a result, so the caller can fall back to
            non-hybrid search.
    """
    results = []
    error = False

    # Fetch collection data once per collection, sequentially, so the
    # (collection, query) tasks below reuse it instead of re-fetching.
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
            )
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None
            # Hybrid search cannot run for this collection. Record the failure
            # so that, if nothing else succeeds, the exception below triggers
            # the caller's non-hybrid fallback instead of an empty result.
            error = True

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    def process_query(collection_name, query):
        # Returns (result, None) on success or (None, exception); never raises.
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
            )
            return result, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Skip collections with no fetched data (fetch failed or returned None).
    tasks = [
        (cn, q)
        for cn in collection_names
        if collection_results[cn] is not None
        for q in queries
    ]

    with ThreadPoolExecutor() as executor:
        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
        )
    return merge_and_sort_query_results(results, k=k)
2024-04-15 05:55:00 +08:00
2025-03-27 16:40:28 +08:00
2024-04-28 03:38:50 +08:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
):
    """Build an embedding callable ``(query, prefix=None, user=None)`` for an engine.

    - ``""``: use the local model object ``embedding_function`` via its
      ``encode`` method, passing ``prompt=prefix`` when a prefix is given.
    - ``"ollama"`` / ``"openai"``: call ``generate_embeddings``. List inputs
      are sent in chunks of ``embedding_batch_size`` and the chunk results are
      concatenated.

    Raises:
        ValueError: for any other ``embedding_engine``.
    """
    if embedding_engine == "":

        def local_embed(query, prefix=None, user=None):
            encode_kwargs = {"prompt": prefix} if prefix else {}
            return embedding_function.encode(query, **encode_kwargs).tolist()

        return local_embed

    if embedding_engine not in ["ollama", "openai"]:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")

    def remote_embed(chunk, prefix=None, user=None):
        return generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=chunk,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
        )

    def batched_embed(query, prefix=None, user=None):
        if not isinstance(query, list):
            return remote_embed(query, prefix, user)
        vectors = []
        for start in range(0, len(query), embedding_batch_size):
            vectors.extend(
                remote_embed(
                    query[start : start + embedding_batch_size],
                    prefix=prefix,
                    user=user,
                )
            )
        return vectors

    return batched_embed
2024-04-23 04:49:58 +08:00
2024-11-22 11:46:09 +08:00
def get_sources_from_files(
    request,
    files,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
):
    """Build retrieval "sources" for the files attached to a request.

    For each entry in ``files`` a context dict is built. It has ``documents``
    and ``metadatas`` (and optionally ``distances``), each wrapped in an outer
    single-element list. The first matching path is used:

    1. ``file["docs"]`` is set: use those docs verbatim.
    2. ``file["context"] == "full"``: use the file's full stored content.
    3. ``BYPASS_EMBEDDING_AND_RETRIEVAL`` is enabled and the entry is not a web
       search: load raw content via ``Files`` without any vector search.
    4. Otherwise: resolve the entry to vector-DB collection names, then either
       fetch everything (``full_context``) or run hybrid or dense search.
       Hybrid search falls back to dense search when it fails.

    Collections already handled for an earlier entry are not queried again.
    Note: the ``"data"`` key is deleted from each file dict that produced a
    context (the input dicts are mutated).

    Returns:
        A list of ``{"source", "document", "metadata"[, "distances"]}`` dicts,
        one per entry that produced a truthy context.
    """
    log.debug(
        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    relevant_contexts = []

    for file in files:
        context = None
        if file.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            context = {
                "documents": [[doc.get("content") for doc in file.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
            }
        elif file.get("context") == "full":
            # Manual Full Mode Toggle
            context = {
                "documents": [[file.get("file").get("data", {}).get("content")]],
                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
            }
        elif (
            file.get("type") != "web_search"
            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            # BYPASS_EMBEDDING_AND_RETRIEVAL
            if file.get("type") == "collection":
                file_ids = file.get("data", {}).get("file_ids", [])

                documents = []
                metadatas = []
                for file_id in file_ids:
                    file_object = Files.get_file_by_id(file_id)
                    if file_object:
                        documents.append(file_object.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file_id,
                                "name": file_object.filename,
                                "source": file_object.filename,
                            }
                        )

                context = {
                    "documents": [documents],
                    "metadatas": [metadatas],
                }
            elif file.get("id"):
                file_object = Files.get_file_by_id(file.get("id"))
                if file_object:
                    context = {
                        "documents": [[file_object.data.get("content", "")]],
                        "metadatas": [
                            [
                                {
                                    "file_id": file.get("id"),
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            ]
                        ],
                    }
            elif (file.get("file") or {}).get("data"):
                # Entries without a "file" key are skipped here instead of
                # raising AttributeError and aborting the whole request.
                file_data = file["file"].get("data", {})
                context = {
                    "documents": [[file_data.get("content")]],
                    "metadatas": [[file_data.get("metadata", {})]],
                }
        else:
            collection_names = []
            if file.get("type") == "collection":
                if file.get("legacy"):
                    collection_names = file.get("collection_names", [])
                else:
                    collection_names.append(file["id"])
            elif file.get("collection_name"):
                collection_names.append(file["collection_name"])
            elif file.get("id"):
                if file.get("legacy"):
                    collection_names.append(f"{file['id']}")
                else:
                    collection_names.append(f"file-{file['id']}")

            # Don't query a collection twice across entries.
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {file} as it has already been extracted")
                continue

            if full_context:
                try:
                    context = get_all_items_from_collections(collection_names)
                except Exception as e:
                    log.exception(e)
            else:
                try:
                    context = None
                    if file.get("type") == "text":
                        # NOTE(review): context becomes file["content"] as-is;
                        # the {**context} below requires it to be a mapping.
                        context = file["content"]
                    else:
                        if hybrid_search:
                            try:
                                context = query_collection_with_hybrid_search(
                                    collection_names=collection_names,
                                    queries=queries,
                                    embedding_function=embedding_function,
                                    k=k,
                                    reranking_function=reranking_function,
                                    k_reranker=k_reranker,
                                    r=r,
                                    hybrid_bm25_weight=hybrid_bm25_weight,
                                )
                            except Exception:
                                # Keep the failure reason instead of dropping it.
                                log.debug(
                                    "Error when using hybrid search, using"
                                    " non hybrid search as fallback.",
                                    exc_info=True,
                                )

                        if (not hybrid_search) or (context is None):
                            context = query_collection(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                            )
                except Exception as e:
                    log.exception(e)

            extracted_collections.extend(collection_names)

        if context:
            if "data" in file:
                del file["data"]

            relevant_contexts.append({**context, "file": file})

    sources = []
    for context in relevant_contexts:
        try:
            if "documents" in context and "metadatas" in context:
                source = {
                    "source": context["file"],
                    "document": context["documents"][0],
                    "metadata": context["metadatas"][0],
                }
                if context.get("distances"):
                    source["distances"] = context["distances"][0]

                sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
2024-04-05 02:07:42 +08:00
2024-04-25 20:49:59 +08:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve a sentence-transformers model name to a local snapshot path.

    A name that is an existing path is returned unchanged. So is a name that
    looks like a filesystem path, when only local files may be used. A bare
    short name is namespaced under ``sentence-transformers/``. Otherwise
    ``huggingface_hub.snapshot_download`` is asked for the snapshot path; it
    downloads or updates the model only when ``update_model`` is set and
    ``OFFLINE_MODE`` is off. If that lookup fails, the (possibly namespaced)
    repo id is returned.
    """
    # Offline mode forces cache-only lookups regardless of update_model.
    local_files_only = True if OFFLINE_MODE else not update_model
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        return model

    repo_id = model if "/" in model else "sentence-transformers" + "/" + model
    snapshot_kwargs["repo_id"] = repo_id

    # Ask huggingface_hub for the local path (and update it if permitted).
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return repo_id
2024-04-25 20:49:59 +08:00
2024-11-19 06:19:56 +08:00
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Embed a batch of texts through an OpenAI-compatible ``/embeddings`` endpoint.

    Args:
        model: Embedding model name sent in the request body.
        texts: Texts to embed.
        url: API base URL; ``/embeddings`` is appended.
        key: Bearer token for the ``Authorization`` header.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` request field
            when both are strings.
        user: When set and ``ENABLE_FORWARD_USER_INFO_HEADERS`` is on, the
            user's name/id/email/role are forwarded as ``X-OpenWebUI-*`` headers.

    Returns:
        One embedding per input text, or ``None`` if anything failed (the
        error is logged).
    """
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers.update(
                {
                    "X-OpenWebUI-User-Name": user.name,
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
            )

        r = requests.post(f"{url}/embeddings", headers=headers, json=json_data)
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        # Raising a plain str is itself a TypeError; raise a real exception so
        # the logged message describes the actual problem.
        raise ValueError("Unexpected embeddings response: missing 'data' field")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
2024-04-23 04:49:58 +08:00
2024-11-19 06:19:56 +08:00
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Embed a batch of texts through an Ollama ``/api/embed`` endpoint.

    Args:
        model: Embedding model name sent in the request body.
        texts: Texts to embed.
        url: Ollama base URL; ``/api/embed`` is appended.
        key: Bearer token for the ``Authorization`` header.
        prefix: Sent as the ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` request field
            when both are strings.
        user: When set and ``ENABLE_FORWARD_USER_INFO_HEADERS`` is on, the
            user's name/id/email/role are forwarded as ``X-OpenWebUI-*`` headers.

    Returns:
        The ``embeddings`` list from the response, or ``None`` if anything
        failed (the error is logged).
    """
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        # Only forward user headers when a user is actually known. Without
        # the `user` check, calls made with user=None (e.g. from the reranker)
        # would raise AttributeError and silently return None.
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers.update(
                {
                    "X-OpenWebUI-User-Name": user.name,
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
            )

        r = requests.post(f"{url}/api/embed", headers=headers, json=json_data)
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        # Raising a plain str is itself a TypeError; raise a real exception.
        raise ValueError("Unexpected embeddings response: missing 'embeddings' field")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
2025-03-31 12:55:15 +08:00
def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    """Embed one text or a list of texts with the given remote engine.

    ``kwargs`` may carry ``url``, ``key`` and ``user``. When a prefix is given
    but no request field for it is configured (``RAG_EMBEDDING_PREFIX_FIELD_NAME``
    is None), the prefix is prepended to each text instead.

    Returns:
        A single embedding for a ``str`` input, the batch result otherwise.
        Returns ``None`` for an engine other than ``"ollama"`` / ``"openai"``.
    """
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{item}" for item in text]
        else:
            text = f"{prefix}{text}"

    # Both backends are batch APIs; wrap a single input in a one-item batch.
    batch = text if isinstance(text, list) else [text]

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            model=model,
            texts=batch,
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, batch, url, key, prefix, user
        )
    else:
        return None

    return embeddings[0] if isinstance(text, str) else embeddings
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-28 06:10:27 +08:00
from langchain_core.documents import BaseDocumentCompressor, Document
class RerankCompressor(BaseDocumentCompressor):
    """Document compressor that scores, filters and truncates retrieved docs.

    Scores come from ``reranking_function.predict`` on (query, text) pairs when
    a reranker is set. Otherwise they are the cosine similarity between the
    query embedding and each document embedding from ``embedding_function``.
    When ``r_score`` is truthy, documents scoring below it are dropped. The
    rest are sorted by descending score, and the best ``top_n`` are returned
    as new ``Document`` objects with the score under ``metadata["score"]``.
    """

    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        # Nothing to score; avoid handing an empty batch to the reranker /
        # embedding backend.
        if not documents:
            return []

        if self.reranking_function is not None:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(
            zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
        )
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Build a fresh metadata dict per result rather than writing "score"
        # into doc.metadata in place. The incoming metadata may be shared,
        # e.g. by BM25 documents built from a collection fetched once and
        # searched by several queries concurrently. In-place writes could
        # then leak one query's score into another query's results.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": doc_score},
            )
            for doc, doc_score in result[: self.top_n]
        ]