| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | import os | 
					
						
							|  |  |  | from typing import Optional, Union | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | import requests | 
					
						
							| 
									
										
										
										
											2025-02-27 15:51:39 +08:00
										 |  |  | import hashlib | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  | from concurrent.futures import ThreadPoolExecutor | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-17 01:33:11 +08:00
										 |  |  | from urllib.parse import quote | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | from huggingface_hub import snapshot_download | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | from langchain_community.retrievers import BM25Retriever | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain_core.documents import Document | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-08 16:21:50 +08:00
										 |  |  | from open_webui.config import VECTOR_DB | 
					
						
							| 
									
										
										
										
											2025-05-12 23:15:32 +08:00
										 |  |  | from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT | 
					
						
							| 
									
										
										
										
											2025-02-21 03:02:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  | from open_webui.models.users import UserModel | 
					
						
							| 
									
										
										
										
											2025-02-27 07:42:19 +08:00
										 |  |  | from open_webui.models.files import Files | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  | from open_webui.models.knowledge import Knowledges | 
					
						
							| 
									
										
										
										
											2025-07-09 05:17:25 +08:00
										 |  |  | from open_webui.models.notes import Notes | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  | from open_webui.retrieval.vector.main import GetResult | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  | from open_webui.utils.access_control import has_access | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-06 07:15:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  | from open_webui.env import ( | 
					
						
							|  |  |  |     SRC_LOG_LEVELS, | 
					
						
							|  |  |  |     OFFLINE_MODE, | 
					
						
							|  |  |  |     ENABLE_FORWARD_USER_INFO_HEADERS, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2025-02-05 05:04:36 +08:00
										 |  |  | from open_webui.config import ( | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |     RAG_EMBEDDING_QUERY_PREFIX, | 
					
						
							|  |  |  |     RAG_EMBEDDING_CONTENT_PREFIX, | 
					
						
							|  |  |  |     RAG_EMBEDDING_PREFIX_FIELD_NAME, | 
					
						
							| 
									
										
										
										
											2025-02-05 05:04:36 +08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
# Module-level logger named after this module; its level is taken from the
# "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  | from typing import Any | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from langchain_core.callbacks import CallbackManagerForRetrieverRun | 
					
						
							|  |  |  | from langchain_core.retrievers import BaseRetriever | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class VectorSearchRetriever(BaseRetriever): | 
					
						
							|  |  |  |     collection_name: Any | 
					
						
							|  |  |  |     embedding_function: Any | 
					
						
							|  |  |  |     top_k: int | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_relevant_documents( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         query: str, | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         *, | 
					
						
							|  |  |  |         run_manager: CallbackManagerForRetrieverRun, | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |     ) -> list[Document]: | 
					
						
							|  |  |  |         result = VECTOR_DB_CLIENT.search( | 
					
						
							|  |  |  |             collection_name=self.collection_name, | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |             vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)], | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             limit=self.top_k, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |         ids = result.ids[0] | 
					
						
							|  |  |  |         metadatas = result.metadatas[0] | 
					
						
							|  |  |  |         documents = result.documents[0] | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         results = [] | 
					
						
							|  |  |  |         for idx in range(len(ids)): | 
					
						
							|  |  |  |             results.append( | 
					
						
							|  |  |  |                 Document( | 
					
						
							|  |  |  |                     metadata=metadatas[idx], | 
					
						
							|  |  |  |                     page_content=documents[idx], | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return results | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_doc( | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |     collection_name: str, query_embedding: list[float], k: int, user: UserModel = None | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-04-09 00:08:32 +08:00
										 |  |  |         log.debug(f"query_doc:doc {collection_name}") | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         result = VECTOR_DB_CLIENT.search( | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             collection_name=collection_name, | 
					
						
							| 
									
										
										
										
											2024-09-30 22:18:02 +08:00
										 |  |  |             vectors=[query_embedding], | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             limit=k, | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if result: | 
					
						
							| 
									
										
										
										
											2024-12-20 12:56:16 +08:00
										 |  |  |             log.info(f"query_doc:result {result.ids} {result.metadatas}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         return result | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-02-25 22:36:25 +08:00
										 |  |  |         log.exception(f"Error querying doc {collection_name} with limit {k}: {e}") | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         raise e | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  | def get_doc(collection_name: str, user: UserModel = None): | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-04-09 00:08:32 +08:00
										 |  |  |         log.debug(f"get_doc:doc {collection_name}") | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |         result = VECTOR_DB_CLIENT.get(collection_name=collection_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if result: | 
					
						
							|  |  |  |             log.info(f"query_doc:result {result.ids} {result.metadatas}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return result | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-02-25 22:36:25 +08:00
										 |  |  |         log.exception(f"Error getting doc {collection_name}: {e}") | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_doc_with_hybrid_search( | 
					
						
							|  |  |  |     collection_name: str, | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |     collection_result: GetResult, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     query: str, | 
					
						
							|  |  |  |     embedding_function, | 
					
						
							|  |  |  |     k: int, | 
					
						
							|  |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2025-03-06 17:47:57 +08:00
										 |  |  |     k_reranker: int, | 
					
						
							| 
									
										
										
										
											2024-05-02 13:45:19 +08:00
										 |  |  |     r: float, | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |     hybrid_bm25_weight: float, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-04-09 00:08:32 +08:00
										 |  |  |         log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         bm25_retriever = BM25Retriever.from_texts( | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |             texts=collection_result.documents[0], | 
					
						
							|  |  |  |             metadatas=collection_result.metadatas[0], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         bm25_retriever.k = k | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |         vector_search_retriever = VectorSearchRetriever( | 
					
						
							|  |  |  |             collection_name=collection_name, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             embedding_function=embedding_function, | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             top_k=k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |         if hybrid_bm25_weight <= 0: | 
					
						
							| 
									
										
										
										
											2025-05-20 16:39:31 +08:00
										 |  |  |             ensemble_retriever = EnsembleRetriever( | 
					
						
							| 
									
										
										
										
											2025-05-24 06:13:54 +08:00
										 |  |  |                 retrievers=[vector_search_retriever], weights=[1.0] | 
					
						
							| 
									
										
										
										
											2025-05-20 16:39:31 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |         elif hybrid_bm25_weight >= 1: | 
					
						
							| 
									
										
										
										
											2025-05-20 16:39:31 +08:00
										 |  |  |             ensemble_retriever = EnsembleRetriever( | 
					
						
							| 
									
										
										
										
											2025-05-24 06:13:54 +08:00
										 |  |  |                 retrievers=[bm25_retriever], weights=[1.0] | 
					
						
							| 
									
										
										
										
											2025-05-20 16:39:31 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             ensemble_retriever = EnsembleRetriever( | 
					
						
							|  |  |  |                 retrievers=[bm25_retriever, vector_search_retriever], | 
					
						
							| 
									
										
										
										
											2025-05-24 06:13:54 +08:00
										 |  |  |                 weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight], | 
					
						
							| 
									
										
										
										
											2025-05-20 16:39:31 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         compressor = RerankCompressor( | 
					
						
							|  |  |  |             embedding_function=embedding_function, | 
					
						
							| 
									
										
										
										
											2025-03-06 17:47:57 +08:00
										 |  |  |             top_n=k_reranker, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             reranking_function=reranking_function, | 
					
						
							|  |  |  |             r_score=r, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         compression_retriever = ContextualCompressionRetriever( | 
					
						
							|  |  |  |             base_compressor=compressor, base_retriever=ensemble_retriever | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         result = compression_retriever.invoke(query) | 
					
						
							| 
									
										
										
										
											2025-03-18 18:31:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |         distances = [d.metadata.get("score") for d in result] | 
					
						
							|  |  |  |         documents = [d.page_content for d in result] | 
					
						
							|  |  |  |         metadatas = [d.metadata for d in result] | 
					
						
							| 
									
										
										
										
											2025-03-18 18:31:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker | 
					
						
							|  |  |  |         if k < k_reranker: | 
					
						
							| 
									
										
										
										
											2025-03-27 16:40:28 +08:00
										 |  |  |             sorted_items = sorted( | 
					
						
							|  |  |  |                 zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-03-18 18:31:17 +08:00
										 |  |  |             sorted_items = sorted_items[:k] | 
					
						
							|  |  |  |             distances, documents, metadatas = map(list, zip(*sorted_items)) | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         result = { | 
					
						
							| 
									
										
										
										
											2025-03-18 18:31:17 +08:00
										 |  |  |             "distances": [distances], | 
					
						
							| 
									
										
										
										
											2025-03-18 19:14:59 +08:00
										 |  |  |             "documents": [documents], | 
					
						
							| 
									
										
										
										
											2025-03-18 18:31:17 +08:00
										 |  |  |             "metadatas": [metadatas], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-29 20:33:37 +08:00
										 |  |  |         log.info( | 
					
						
							| 
									
										
										
										
											2024-11-07 15:01:10 +08:00
										 |  |  |             "query_doc_with_hybrid_search:result " | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |             + f'{result["metadatas"]} {result["distances"]}' | 
					
						
							| 
									
										
										
										
											2024-10-29 20:33:37 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |         return result | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-04-09 00:08:32 +08:00
										 |  |  |         log.exception(f"Error querying doc {collection_name} with hybrid search: {e}") | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  | def merge_get_results(get_results: list[dict]) -> dict: | 
					
						
							|  |  |  |     # Initialize lists to store combined data | 
					
						
							|  |  |  |     combined_documents = [] | 
					
						
							|  |  |  |     combined_metadatas = [] | 
					
						
							| 
									
										
										
										
											2025-02-19 15:49:27 +08:00
										 |  |  |     combined_ids = [] | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for data in get_results: | 
					
						
							|  |  |  |         combined_documents.extend(data["documents"][0]) | 
					
						
							|  |  |  |         combined_metadatas.extend(data["metadatas"][0]) | 
					
						
							| 
									
										
										
										
											2025-02-19 15:49:27 +08:00
										 |  |  |         combined_ids.extend(data["ids"][0]) | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Create the output dictionary | 
					
						
							|  |  |  |     result = { | 
					
						
							|  |  |  |         "documents": [combined_documents], | 
					
						
							|  |  |  |         "metadatas": [combined_metadatas], | 
					
						
							| 
									
										
										
										
											2025-02-19 15:49:27 +08:00
										 |  |  |         "ids": [combined_ids], | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-26 02:09:17 +08:00
										 |  |  | def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict: | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |     # Initialize lists to store combined data | 
					
						
							| 
									
										
										
										
											2025-03-19 23:06:10 +08:00
										 |  |  |     combined = dict()  # To store documents with unique document hashes | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for data in query_results: | 
					
						
							| 
									
										
										
										
											2025-02-27 15:51:39 +08:00
										 |  |  |         distances = data["distances"][0] | 
					
						
							|  |  |  |         documents = data["documents"][0] | 
					
						
							|  |  |  |         metadatas = data["metadatas"][0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for distance, document, metadata in zip(distances, documents, metadatas): | 
					
						
							|  |  |  |             if isinstance(document, str): | 
					
						
							| 
									
										
										
										
											2025-05-01 15:56:20 +08:00
										 |  |  |                 doc_hash = hashlib.sha256( | 
					
						
							| 
									
										
										
										
											2025-02-27 15:51:39 +08:00
										 |  |  |                     document.encode() | 
					
						
							|  |  |  |                 ).hexdigest()  # Compute a hash for uniqueness | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-19 23:06:10 +08:00
										 |  |  |                 if doc_hash not in combined.keys(): | 
					
						
							|  |  |  |                     combined[doc_hash] = (distance, document, metadata) | 
					
						
							|  |  |  |                     continue  # if doc is new, no further comparison is needed | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-19 23:06:10 +08:00
										 |  |  |                 # if doc is alredy in, but new distance is better, update | 
					
						
							| 
									
										
										
										
											2025-03-25 23:46:14 +08:00
										 |  |  |                 if distance > combined[doc_hash][0]: | 
					
						
							| 
									
										
										
										
											2025-03-19 23:06:10 +08:00
										 |  |  |                     combined[doc_hash] = (distance, document, metadata) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     combined = list(combined.values()) | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |     # Sort the list based on distances | 
					
						
							| 
									
										
										
										
											2025-03-25 23:46:14 +08:00
										 |  |  |     combined.sort(key=lambda x: x[0], reverse=True) | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-27 15:51:39 +08:00
										 |  |  |     # Slice to keep only the top k elements | 
					
						
							|  |  |  |     sorted_distances, sorted_documents, sorted_metadatas = ( | 
					
						
							|  |  |  |         zip(*combined[:k]) if combined else ([], [], []) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-02-21 03:02:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-27 15:51:39 +08:00
										 |  |  |     # Create and return the output dictionary | 
					
						
							|  |  |  |     return { | 
					
						
							|  |  |  |         "distances": [list(sorted_distances)], | 
					
						
							|  |  |  |         "documents": [list(sorted_documents)], | 
					
						
							|  |  |  |         "metadatas": [list(sorted_metadatas)], | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  | def get_all_items_from_collections(collection_names: list[str]) -> dict: | 
					
						
							|  |  |  |     results = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for collection_name in collection_names: | 
					
						
							|  |  |  |         if collection_name: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 result = get_doc(collection_name=collection_name) | 
					
						
							|  |  |  |                 if result is not None: | 
					
						
							|  |  |  |                     results.append(result.model_dump()) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.exception(f"Error when querying the collection: {e}") | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return merge_get_results(results) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_collection( | 
					
						
							| 
									
										
										
										
											2024-08-14 20:46:31 +08:00
										 |  |  |     collection_names: list[str], | 
					
						
							| 
									
										
										
										
											2024-11-19 18:24:32 +08:00
										 |  |  |     queries: list[str], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function, | 
					
						
							|  |  |  |     k: int, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     results = [] | 
					
						
							| 
									
										
										
										
											2025-04-23 17:17:12 +08:00
										 |  |  |     error = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def process_query_collection(collection_name, query_embedding): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |             if collection_name: | 
					
						
							| 
									
										
										
										
											2025-04-23 17:17:12 +08:00
										 |  |  |                 result = query_doc( | 
					
						
							|  |  |  |                     collection_name=collection_name, | 
					
						
							|  |  |  |                     k=k, | 
					
						
							|  |  |  |                     query_embedding=query_embedding, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 if result is not None: | 
					
						
							|  |  |  |                     return result.model_dump(), None | 
					
						
							|  |  |  |             return None, None | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(f"Error when querying the collection: {e}") | 
					
						
							|  |  |  |             return None, e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Generate all query embeddings (in one call) | 
					
						
							|  |  |  |     query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) | 
					
						
							|  |  |  |     log.debug( | 
					
						
							|  |  |  |         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with ThreadPoolExecutor() as executor: | 
					
						
							|  |  |  |         future_results = [] | 
					
						
							|  |  |  |         for query_embedding in query_embeddings: | 
					
						
							|  |  |  |             for collection_name in collection_names: | 
					
						
							|  |  |  |                 result = executor.submit( | 
					
						
							|  |  |  |                     process_query_collection, collection_name, query_embedding | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 future_results.append(result) | 
					
						
							|  |  |  |         task_results = [future.result() for future in future_results] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for result, err in task_results: | 
					
						
							|  |  |  |         if err is not None: | 
					
						
							|  |  |  |             error = True | 
					
						
							|  |  |  |         elif result is not None: | 
					
						
							|  |  |  |             results.append(result) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if error and not results: | 
					
						
							|  |  |  |         log.warning("All collection queries failed. No results returned.") | 
					
						
							| 
									
										
										
										
											2024-08-23 21:02:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-25 23:46:14 +08:00
										 |  |  |     return merge_and_sort_query_results(results, k=k) | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def query_collection_with_hybrid_search( | 
					
						
							| 
									
										
										
										
											2024-08-14 20:46:31 +08:00
										 |  |  |     collection_names: list[str], | 
					
						
							| 
									
										
										
										
											2024-11-19 18:24:32 +08:00
										 |  |  |     queries: list[str], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     k: int, | 
					
						
							|  |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2025-03-06 17:47:57 +08:00
										 |  |  |     k_reranker: int, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     r: float, | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |     hybrid_bm25_weight: float, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     results = [] | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |     error = False | 
					
						
							| 
									
										
										
										
											2025-03-28 02:05:20 +08:00
										 |  |  |     # Fetch collection data once per collection sequentially | 
					
						
							|  |  |  |     # Avoid fetching the same data multiple times later | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |     collection_results = {} | 
					
						
							| 
									
										
										
										
											2025-03-28 02:05:20 +08:00
										 |  |  |     for collection_name in collection_names: | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-04-13 07:35:11 +08:00
										 |  |  |             log.debug( | 
					
						
							|  |  |  |                 f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}" | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |             collection_results[collection_name] = VECTOR_DB_CLIENT.get( | 
					
						
							|  |  |  |                 collection_name=collection_name | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-03-28 02:05:20 +08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(f"Failed to fetch collection {collection_name}: {e}") | 
					
						
							| 
									
										
										
										
											2025-03-31 11:48:22 +08:00
										 |  |  |             collection_results[collection_name] = None | 
					
						
							| 
									
										
										
										
											2025-03-28 02:05:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 08:59:21 +08:00
										 |  |  |     log.info( | 
					
						
							|  |  |  |         f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..." | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  |     def process_query(collection_name, query): | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  |             result = query_doc_with_hybrid_search( | 
					
						
							|  |  |  |                 collection_name=collection_name, | 
					
						
							|  |  |  |                 collection_result=collection_results[collection_name], | 
					
						
							|  |  |  |                 query=query, | 
					
						
							|  |  |  |                 embedding_function=embedding_function, | 
					
						
							|  |  |  |                 k=k, | 
					
						
							|  |  |  |                 reranking_function=reranking_function, | 
					
						
							|  |  |  |                 k_reranker=k_reranker, | 
					
						
							|  |  |  |                 r=r, | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |                 hybrid_bm25_weight=hybrid_bm25_weight, | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  |             return result, None | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(f"Error when querying the collection with hybrid_search: {e}") | 
					
						
							|  |  |  |             return None, e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-05 16:41:21 +08:00
										 |  |  |     # Prepare tasks for all collections and queries | 
					
						
							|  |  |  |     # Avoid running any tasks for collections that failed to fetch data (have assigned None) | 
					
						
							| 
									
										
										
										
											2025-04-05 22:03:24 +08:00
										 |  |  |     tasks = [ | 
					
						
							|  |  |  |         (cn, q) | 
					
						
							|  |  |  |         for cn in collection_names | 
					
						
							|  |  |  |         if collection_results[cn] is not None | 
					
						
							|  |  |  |         for q in queries | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     with ThreadPoolExecutor() as executor: | 
					
						
							|  |  |  |         future_results = [executor.submit(process_query, cn, q) for cn, q in tasks] | 
					
						
							|  |  |  |         task_results = [future.result() for future in future_results] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for result, err in task_results: | 
					
						
							|  |  |  |         if err is not None: | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |             error = True | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  |         elif result is not None: | 
					
						
							|  |  |  |             results.append(result) | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-31 22:43:37 +08:00
										 |  |  |     if error and not results: | 
					
						
							| 
									
										
										
										
											2025-04-01 08:59:21 +08:00
										 |  |  |         raise Exception( | 
					
						
							|  |  |  |             "Hybrid search failed for all collections. Using Non-hybrid search as fallback." | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-25 23:46:14 +08:00
										 |  |  |     return merge_and_sort_query_results(results, k=k) | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-27 16:40:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def get_embedding_function( | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     embedding_engine, | 
					
						
							|  |  |  |     embedding_model, | 
					
						
							|  |  |  |     embedding_function, | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |     url, | 
					
						
							|  |  |  |     key, | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |     embedding_batch_size, | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |     azure_api_version=None, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | ): | 
					
						
							|  |  |  |     if embedding_engine == "": | 
					
						
							| 
									
										
										
										
											2025-04-01 05:13:27 +08:00
										 |  |  |         return lambda query, prefix=None, user=None: embedding_function.encode( | 
					
						
							| 
									
										
										
										
											2025-04-07 08:17:24 +08:00
										 |  |  |             query, **({"prompt": prefix} if prefix else {}) | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |         ).tolist() | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |     elif embedding_engine in ["ollama", "openai", "azure_openai"]: | 
					
						
							| 
									
										
										
										
											2025-04-01 05:13:27 +08:00
										 |  |  |         func = lambda query, prefix=None, user=None: generate_embeddings( | 
					
						
							| 
									
										
										
										
											2024-10-10 03:05:16 +08:00
										 |  |  |             engine=embedding_engine, | 
					
						
							|  |  |  |             model=embedding_model, | 
					
						
							|  |  |  |             text=query, | 
					
						
							| 
									
										
										
										
											2025-01-16 09:05:04 +08:00
										 |  |  |             prefix=prefix, | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |             url=url, | 
					
						
							|  |  |  |             key=key, | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |             user=user, | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |             azure_api_version=azure_api_version, | 
					
						
							| 
									
										
										
										
											2024-10-10 03:05:16 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-06 07:15:24 +08:00
										 |  |  |         def generate_multiple(query, prefix, user, func): | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |             if isinstance(query, list): | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |                 embeddings = [] | 
					
						
							|  |  |  |                 for i in range(0, len(query), embedding_batch_size): | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |                     embeddings.extend( | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |                         func( | 
					
						
							|  |  |  |                             query[i : i + embedding_batch_size], | 
					
						
							|  |  |  |                             prefix=prefix, | 
					
						
							|  |  |  |                             user=user, | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |                 return embeddings | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2025-02-06 07:15:24 +08:00
										 |  |  |                 return func(query, prefix, user) | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 05:13:27 +08:00
										 |  |  |         return lambda query, prefix=None, user=None: generate_multiple( | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |             query, prefix, user, func | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |     else: | 
					
						
							|  |  |  |         raise ValueError(f"Unknown embedding engine: {embedding_engine}") | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:59:10 +08:00
										 |  |  | def get_reranking_function(reranking_engine, reranking_model, reranking_function): | 
					
						
							| 
									
										
										
										
											2025-07-14 18:05:06 +08:00
										 |  |  |     if reranking_function is None: | 
					
						
							|  |  |  |         return None | 
					
						
							| 
									
										
										
										
											2025-07-14 17:59:10 +08:00
										 |  |  |     if reranking_engine == "external": | 
					
						
							|  |  |  |         return lambda sentences, user=None: reranking_function.predict( | 
					
						
							|  |  |  |             sentences, user=user | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return lambda sentences, user=None: reranking_function.predict(sentences) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  | def get_sources_from_items( | 
					
						
							| 
									
										
										
										
											2025-02-27 07:42:19 +08:00
										 |  |  |     request, | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |     items, | 
					
						
							| 
									
										
										
										
											2024-11-19 18:24:32 +08:00
										 |  |  |     queries, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  |     k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2025-03-06 17:47:57 +08:00
										 |  |  |     k_reranker, | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     r, | 
					
						
							| 
									
										
										
										
											2025-05-24 04:06:44 +08:00
										 |  |  |     hybrid_bm25_weight, | 
					
						
							| 
									
										
										
										
											2024-04-27 02:41:39 +08:00
										 |  |  |     hybrid_search, | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |     full_context=False, | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |     user: Optional[UserModel] = None, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | ): | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |     log.debug( | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |         f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}" | 
					
						
							| 
									
										
										
										
											2025-02-19 13:14:58 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     extracted_collections = [] | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |     query_results = [] | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |     for item in items: | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |         query_result = None | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |         collection_names = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |         if item.get("type") == "text": | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |             # Raw Text | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |             # Used during temporary chat file uploads | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if item.get("file"): | 
					
						
							|  |  |  |                 # if item has file data, use it | 
					
						
							|  |  |  |                 query_result = { | 
					
						
							| 
									
										
										
										
											2025-07-14 21:50:03 +08:00
										 |  |  |                     "documents": [ | 
					
						
							|  |  |  |                         [item.get("file", {}).get("data", {}).get("content")] | 
					
						
							|  |  |  |                     ], | 
					
						
							|  |  |  |                     "metadatas": [ | 
					
						
							|  |  |  |                         [item.get("file", {}).get("data", {}).get("meta", {})] | 
					
						
							|  |  |  |                     ], | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 # Fallback to item content | 
					
						
							|  |  |  |                 query_result = { | 
					
						
							|  |  |  |                     "documents": [[item.get("content")]], | 
					
						
							|  |  |  |                     "metadatas": [ | 
					
						
							|  |  |  |                         [{"file_id": item.get("id"), "name": item.get("name")}] | 
					
						
							|  |  |  |                     ], | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         elif item.get("type") == "note": | 
					
						
							| 
									
										
										
										
											2025-07-09 05:17:25 +08:00
										 |  |  |             # Note Attached | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |             note = Notes.get_note_by_id(item.get("id")) | 
					
						
							| 
									
										
										
										
											2025-07-09 05:17:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-22 15:38:47 +08:00
										 |  |  |             if note and ( | 
					
						
							| 
									
										
										
										
											2025-07-22 21:17:26 +08:00
										 |  |  |                 user.role == "admin" | 
					
						
							|  |  |  |                 or note.user_id == user.id | 
					
						
							|  |  |  |                 or has_access(user.id, "read", note.access_control) | 
					
						
							| 
									
										
										
										
											2025-07-22 15:38:47 +08:00
										 |  |  |             ): | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                 # User has access to the note | 
					
						
							|  |  |  |                 query_result = { | 
					
						
							|  |  |  |                     "documents": [[note.data.get("content", {}).get("md", "")]], | 
					
						
							|  |  |  |                     "metadatas": [[{"file_id": note.id, "name": note.title}]], | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |         elif item.get("type") == "file": | 
					
						
							|  |  |  |             if ( | 
					
						
							|  |  |  |                 item.get("context") == "full" | 
					
						
							|  |  |  |                 or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2025-07-14 21:50:03 +08:00
										 |  |  |                 if item.get("file", {}).get("data", {}).get("content", ""): | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |                     # Manual Full Mode Toggle | 
					
						
							|  |  |  |                     # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content") | 
					
						
							|  |  |  |                     query_result = { | 
					
						
							|  |  |  |                         "documents": [ | 
					
						
							| 
									
										
										
										
											2025-07-14 21:50:03 +08:00
										 |  |  |                             [item.get("file", {}).get("data", {}).get("content", "")] | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |                         ], | 
					
						
							|  |  |  |                         "metadatas": [ | 
					
						
							|  |  |  |                             [ | 
					
						
							|  |  |  |                                 { | 
					
						
							|  |  |  |                                     "file_id": item.get("id"), | 
					
						
							|  |  |  |                                     "name": item.get("name"), | 
					
						
							|  |  |  |                                     **item.get("file") | 
					
						
							|  |  |  |                                     .get("data", {}) | 
					
						
							|  |  |  |                                     .get("metadata", {}), | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             ] | 
					
						
							|  |  |  |                         ], | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 elif item.get("id"): | 
					
						
							|  |  |  |                     file_object = Files.get_file_by_id(item.get("id")) | 
					
						
							|  |  |  |                     if file_object: | 
					
						
							|  |  |  |                         query_result = { | 
					
						
							|  |  |  |                             "documents": [[file_object.data.get("content", "")]], | 
					
						
							|  |  |  |                             "metadatas": [ | 
					
						
							|  |  |  |                                 [ | 
					
						
							|  |  |  |                                     { | 
					
						
							|  |  |  |                                         "file_id": item.get("id"), | 
					
						
							|  |  |  |                                         "name": file_object.filename, | 
					
						
							|  |  |  |                                         "source": file_object.filename, | 
					
						
							|  |  |  |                                     } | 
					
						
							|  |  |  |                                 ] | 
					
						
							|  |  |  |                             ], | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 # Fallback to collection names | 
					
						
							|  |  |  |                 if item.get("legacy"): | 
					
						
							|  |  |  |                     collection_names.append(f"{item['id']}") | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     collection_names.append(f"file-{item['id']}") | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |         elif item.get("type") == "collection": | 
					
						
							|  |  |  |             if ( | 
					
						
							|  |  |  |                 item.get("context") == "full" | 
					
						
							|  |  |  |                 or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2025-07-09 05:29:49 +08:00
										 |  |  |                 # Manual Full Mode Toggle for Collection | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                 knowledge_base = Knowledges.get_knowledge_by_id(item.get("id")) | 
					
						
							| 
									
										
										
										
											2025-07-09 05:29:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                 if knowledge_base and ( | 
					
						
							|  |  |  |                     user.role == "admin" | 
					
						
							|  |  |  |                     or has_access(user.id, "read", knowledge_base.access_control) | 
					
						
							|  |  |  |                 ): | 
					
						
							| 
									
										
										
										
											2025-07-09 05:29:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                     file_ids = knowledge_base.data.get("file_ids", []) | 
					
						
							| 
									
										
										
										
											2025-07-09 05:29:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                     documents = [] | 
					
						
							|  |  |  |                     metadatas = [] | 
					
						
							|  |  |  |                     for file_id in file_ids: | 
					
						
							|  |  |  |                         file_object = Files.get_file_by_id(file_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if file_object: | 
					
						
							|  |  |  |                             documents.append(file_object.data.get("content", "")) | 
					
						
							|  |  |  |                             metadatas.append( | 
					
						
							|  |  |  |                                 { | 
					
						
							|  |  |  |                                     "file_id": file_id, | 
					
						
							|  |  |  |                                     "name": file_object.filename, | 
					
						
							|  |  |  |                                     "source": file_object.filename, | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     query_result = { | 
					
						
							|  |  |  |                         "documents": [documents], | 
					
						
							|  |  |  |                         "metadatas": [metadatas], | 
					
						
							|  |  |  |                     } | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 # Fallback to collection names | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                 if item.get("legacy"): | 
					
						
							|  |  |  |                     collection_names = item.get("collection_names", []) | 
					
						
							| 
									
										
										
										
											2024-10-04 14:06:47 +08:00
										 |  |  |                 else: | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                     collection_names.append(item["id"]) | 
					
						
							| 
									
										
										
										
											2024-05-07 06:49:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  |         elif item.get("docs"): | 
					
						
							|  |  |  |             # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL | 
					
						
							|  |  |  |             query_result = { | 
					
						
							|  |  |  |                 "documents": [[doc.get("content") for doc in item.get("docs")]], | 
					
						
							|  |  |  |                 "metadatas": [[doc.get("metadata") for doc in item.get("docs")]], | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         elif item.get("collection_name"): | 
					
						
							|  |  |  |             # Direct Collection Name | 
					
						
							|  |  |  |             collection_names.append(item["collection_name"]) | 
					
						
							| 
									
										
										
										
											2025-07-16 01:57:24 +08:00
										 |  |  |         elif item.get("collection_names"): | 
					
						
							|  |  |  |             # Collection Names List | 
					
						
							|  |  |  |             collection_names.extend(item["collection_names"]) | 
					
						
							| 
									
										
										
										
											2025-07-11 16:29:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # If query_result is None | 
					
						
							|  |  |  |         # Fallback to collection names and vector search the collections | 
					
						
							|  |  |  |         if query_result is None and collection_names: | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             collection_names = set(collection_names).difference(extracted_collections) | 
					
						
							|  |  |  |             if not collection_names: | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |                 log.debug(f"skipping {item} as it has already been extracted") | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 if full_context: | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |                     query_result = get_all_items_from_collections(collection_names) | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  |                 else: | 
					
						
							|  |  |  |                     query_result = None  # Initialize to None | 
					
						
							|  |  |  |                     if hybrid_search: | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             query_result = query_collection_with_hybrid_search( | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                                 collection_names=collection_names, | 
					
						
							| 
									
										
										
										
											2024-11-19 18:24:32 +08:00
										 |  |  |                                 queries=queries, | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                                 embedding_function=embedding_function, | 
					
						
							|  |  |  |                                 k=k, | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  |                                 reranking_function=reranking_function, | 
					
						
							|  |  |  |                                 k_reranker=k_reranker, | 
					
						
							|  |  |  |                                 r=r, | 
					
						
							|  |  |  |                                 hybrid_bm25_weight=hybrid_bm25_weight, | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  |                         except Exception as e: | 
					
						
							|  |  |  |                             log.debug( | 
					
						
							|  |  |  |                                 "Error when using hybrid search, using non hybrid search as fallback." | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                             ) | 
					
						
							| 
									
										
										
										
											2025-07-11 16:35:42 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     # fallback to non-hybrid search | 
					
						
							|  |  |  |                     if not hybrid_search and query_result is None: | 
					
						
							|  |  |  |                         query_result = query_collection( | 
					
						
							|  |  |  |                             collection_names=collection_names, | 
					
						
							|  |  |  |                             queries=queries, | 
					
						
							|  |  |  |                             embedding_function=embedding_function, | 
					
						
							|  |  |  |                             k=k, | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:58:26 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             extracted_collections.extend(collection_names) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |         if query_result: | 
					
						
							| 
									
										
										
										
											2025-07-11 16:00:21 +08:00
										 |  |  |             if "data" in item: | 
					
						
							|  |  |  |                 del item["data"] | 
					
						
							|  |  |  |             query_results.append({**query_result, "file": item}) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-22 11:46:09 +08:00
										 |  |  |     sources = [] | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |     for query_result in query_results: | 
					
						
							| 
									
										
										
										
											2024-04-30 13:51:30 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |             if "documents" in query_result: | 
					
						
							|  |  |  |                 if "metadatas" in query_result: | 
					
						
							| 
									
										
										
										
											2024-11-22 11:46:09 +08:00
										 |  |  |                     source = { | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |                         "source": query_result["file"], | 
					
						
							|  |  |  |                         "document": query_result["documents"][0], | 
					
						
							|  |  |  |                         "metadata": query_result["metadatas"][0], | 
					
						
							| 
									
										
										
										
											2024-10-08 03:13:13 +08:00
										 |  |  |                     } | 
					
						
							| 
									
										
										
										
											2025-06-25 16:20:08 +08:00
										 |  |  |                     if "distances" in query_result and query_result["distances"]: | 
					
						
							|  |  |  |                         source["distances"] = query_result["distances"][0] | 
					
						
							| 
									
										
										
										
											2024-11-22 11:46:09 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     sources.append(source) | 
					
						
							| 
									
										
										
										
											2024-04-30 13:51:30 +08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-05-07 06:14:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-22 11:46:09 +08:00
										 |  |  |     return sources | 
					
						
							| 
									
										
										
										
											2024-04-05 01:01:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-05 02:07:42 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | def get_model_path(model: str, update_model: bool = False): | 
					
						
							|  |  |  |     # Construct huggingface_hub kwargs with local_files_only to return the snapshot path | 
					
						
							|  |  |  |     cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     local_files_only = not update_model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-29 14:23:09 +08:00
										 |  |  |     if OFFLINE_MODE: | 
					
						
							|  |  |  |         local_files_only = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |     snapshot_kwargs = { | 
					
						
							|  |  |  |         "cache_dir": cache_dir, | 
					
						
							|  |  |  |         "local_files_only": local_files_only, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-26 02:28:31 +08:00
										 |  |  |     log.debug(f"model: {model}") | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |     log.debug(f"snapshot_kwargs: {snapshot_kwargs}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Inspiration from upstream sentence_transformers | 
					
						
							|  |  |  |     if ( | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         os.path.exists(model) | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |         or ("\\" in model or model.count("/") > 1) | 
					
						
							|  |  |  |         and local_files_only | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         # If fully qualified path exists, return input, else set repo_id | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         return model | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |     elif "/" not in model: | 
					
						
							|  |  |  |         # Set valid repo_id for model short-name | 
					
						
							|  |  |  |         model = "sentence-transformers" + "/" + model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     snapshot_kwargs["repo_id"] = model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Attempt to query the huggingface_hub library to determine the local path and/or to update | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         model_repo_path = snapshot_download(**snapshot_kwargs) | 
					
						
							|  |  |  |         log.debug(f"model_repo_path: {model_repo_path}") | 
					
						
							|  |  |  |         return model_repo_path | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.exception(f"Cannot determine model snapshot path: {e}") | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         return model | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  | def generate_openai_batch_embeddings( | 
					
						
							| 
									
										
										
										
											2025-02-05 16:07:45 +08:00
										 |  |  |     model: str, | 
					
						
							|  |  |  |     texts: list[str], | 
					
						
							|  |  |  |     url: str = "https://api.openai.com/v1", | 
					
						
							|  |  |  |     key: str = "", | 
					
						
							| 
									
										
										
										
											2025-02-06 07:15:24 +08:00
										 |  |  |     prefix: str = None, | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |     user: UserModel = None, | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  | ) -> Optional[list[list[float]]]: | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-04-13 07:35:11 +08:00
										 |  |  |         log.debug( | 
					
						
							|  |  |  |             f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |         json_data = {"input": texts, "model": model} | 
					
						
							|  |  |  |         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         r = requests.post( | 
					
						
							| 
									
										
										
										
											2024-04-21 04:15:59 +08:00
										 |  |  |             f"{url}/embeddings", | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |             headers={ | 
					
						
							|  |  |  |                 "Content-Type": "application/json", | 
					
						
							|  |  |  |                 "Authorization": f"Bearer {key}", | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |                 **( | 
					
						
							|  |  |  |                     { | 
					
						
							| 
									
										
										
										
											2025-07-11 03:00:14 +08:00
										 |  |  |                         "X-OpenWebUI-User-Name": quote(user.name, safe=" "), | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Id": user.id, | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Email": user.email, | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Role": user.role, | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |                     } | 
					
						
							|  |  |  |                     if ENABLE_FORWARD_USER_INFO_HEADERS and user | 
					
						
							|  |  |  |                     else {} | 
					
						
							|  |  |  |                 ), | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |             }, | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |             json=json_data, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  |         data = r.json() | 
					
						
							|  |  |  |         if "data" in data: | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |             return [elem["embedding"] for elem in data["data"]] | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise "Something went wrong :/" | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-02-25 22:36:25 +08:00
										 |  |  |         log.exception(f"Error generating openai batch embeddings: {e}") | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         return None | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  | def generate_azure_openai_batch_embeddings( | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |     model: str, | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |     texts: list[str], | 
					
						
							|  |  |  |     url: str, | 
					
						
							|  |  |  |     key: str = "", | 
					
						
							|  |  |  |     version: str = "", | 
					
						
							|  |  |  |     prefix: str = None, | 
					
						
							|  |  |  |     user: UserModel = None, | 
					
						
							|  |  |  | ) -> Optional[list[list[float]]]: | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         log.debug( | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |             f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |         json_data = {"input": texts} | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): | 
					
						
							|  |  |  |             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |         url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}" | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         for _ in range(5): | 
					
						
							|  |  |  |             r = requests.post( | 
					
						
							|  |  |  |                 url, | 
					
						
							|  |  |  |                 headers={ | 
					
						
							|  |  |  |                     "Content-Type": "application/json", | 
					
						
							|  |  |  |                     "api-key": key, | 
					
						
							|  |  |  |                     **( | 
					
						
							|  |  |  |                         { | 
					
						
							| 
									
										
										
										
											2025-07-11 03:00:14 +08:00
										 |  |  |                             "X-OpenWebUI-User-Name": quote(user.name, safe=" "), | 
					
						
							|  |  |  |                             "X-OpenWebUI-User-Id": user.id, | 
					
						
							|  |  |  |                             "X-OpenWebUI-User-Email": user.email, | 
					
						
							|  |  |  |                             "X-OpenWebUI-User-Role": user.role, | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |                         } | 
					
						
							|  |  |  |                         if ENABLE_FORWARD_USER_INFO_HEADERS and user | 
					
						
							|  |  |  |                         else {} | 
					
						
							|  |  |  |                     ), | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |                 json=json_data, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             if r.status_code == 429: | 
					
						
							|  |  |  |                 retry = float(r.headers.get("Retry-After", "1")) | 
					
						
							|  |  |  |                 time.sleep(retry) | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  |             data = r.json() | 
					
						
							|  |  |  |             if "data" in data: | 
					
						
							|  |  |  |                 return [elem["embedding"] for elem in data["data"]] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 raise Exception("Something went wrong :/") | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.exception(f"Error generating azure openai batch embeddings: {e}") | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  | def generate_ollama_batch_embeddings( | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |     model: str, | 
					
						
							| 
									
										
										
										
											2025-02-06 07:15:24 +08:00
										 |  |  |     texts: list[str], | 
					
						
							|  |  |  |     url: str, | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |     key: str = "", | 
					
						
							|  |  |  |     prefix: str = None, | 
					
						
							|  |  |  |     user: UserModel = None, | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  | ) -> Optional[list[list[float]]]: | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-04-13 07:35:11 +08:00
										 |  |  |         log.debug( | 
					
						
							|  |  |  |             f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |         json_data = {"input": texts, "model": model} | 
					
						
							|  |  |  |         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         r = requests.post( | 
					
						
							|  |  |  |             f"{url}/api/embed", | 
					
						
							|  |  |  |             headers={ | 
					
						
							|  |  |  |                 "Content-Type": "application/json", | 
					
						
							|  |  |  |                 "Authorization": f"Bearer {key}", | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |                 **( | 
					
						
							|  |  |  |                     { | 
					
						
							| 
									
										
										
										
											2025-07-11 03:00:14 +08:00
										 |  |  |                         "X-OpenWebUI-User-Name": quote(user.name, safe=" "), | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Id": user.id, | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Email": user.email, | 
					
						
							|  |  |  |                         "X-OpenWebUI-User-Role": user.role, | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |                     } | 
					
						
							|  |  |  |                     if ENABLE_FORWARD_USER_INFO_HEADERS | 
					
						
							|  |  |  |                     else {} | 
					
						
							|  |  |  |                 ), | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |             }, | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |             json=json_data, | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |         r.raise_for_status() | 
					
						
							| 
									
										
										
										
											2024-12-31 08:55:29 +08:00
										 |  |  |         data = r.json() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if "embeddings" in data: | 
					
						
							|  |  |  |             return data["embeddings"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise "Something went wrong :/" | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-02-25 22:36:25 +08:00
										 |  |  |         log.exception(f"Error generating ollama batch embeddings: {e}") | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  | def generate_embeddings( | 
					
						
							|  |  |  |     engine: str, | 
					
						
							|  |  |  |     model: str, | 
					
						
							|  |  |  |     text: Union[str, list[str]], | 
					
						
							|  |  |  |     prefix: Union[str, None] = None, | 
					
						
							|  |  |  |     **kwargs, | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |     url = kwargs.get("url", "") | 
					
						
							|  |  |  |     key = kwargs.get("key", "") | 
					
						
							| 
									
										
										
										
											2025-01-29 18:55:52 +08:00
										 |  |  |     user = kwargs.get("user") | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |     if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None: | 
					
						
							|  |  |  |         if isinstance(text, list): | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |             text = [f"{prefix}{text_element}" for text_element in text] | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |             text = f"{prefix}{text}" | 
					
						
							| 
									
										
										
										
											2025-02-06 06:03:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-10 02:41:35 +08:00
										 |  |  |     if engine == "ollama": | 
					
						
							| 
									
										
										
										
											2025-05-30 05:19:56 +08:00
										 |  |  |         embeddings = generate_ollama_batch_embeddings( | 
					
						
							|  |  |  |             **{ | 
					
						
							|  |  |  |                 "model": model, | 
					
						
							|  |  |  |                 "texts": text if isinstance(text, list) else [text], | 
					
						
							|  |  |  |                 "url": url, | 
					
						
							|  |  |  |                 "key": key, | 
					
						
							|  |  |  |                 "prefix": prefix, | 
					
						
							|  |  |  |                 "user": user, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-11-19 06:19:56 +08:00
										 |  |  |         return embeddings[0] if isinstance(text, str) else embeddings | 
					
						
							| 
									
										
										
										
											2024-10-10 02:41:35 +08:00
										 |  |  |     elif engine == "openai": | 
					
						
							| 
									
										
										
										
											2025-05-30 05:19:56 +08:00
										 |  |  |         embeddings = generate_openai_batch_embeddings( | 
					
						
							|  |  |  |             model, text if isinstance(text, list) else [text], url, key, prefix, user | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-10 02:41:35 +08:00
										 |  |  |         return embeddings[0] if isinstance(text, str) else embeddings | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |     elif engine == "azure_openai": | 
					
						
							| 
									
										
										
										
											2025-05-30 04:34:18 +08:00
										 |  |  |         azure_api_version = kwargs.get("azure_api_version", "") | 
					
						
							| 
									
										
										
										
											2025-05-30 05:19:56 +08:00
										 |  |  |         embeddings = generate_azure_openai_batch_embeddings( | 
					
						
							|  |  |  |             model, | 
					
						
							|  |  |  |             text if isinstance(text, list) else [text], | 
					
						
							|  |  |  |             url, | 
					
						
							|  |  |  |             key, | 
					
						
							|  |  |  |             azure_api_version, | 
					
						
							|  |  |  |             prefix, | 
					
						
							|  |  |  |             user, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-05-20 10:58:04 +08:00
										 |  |  |         return embeddings[0] if isinstance(text, str) else embeddings | 
					
						
							| 
									
										
										
										
											2024-09-27 06:28:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | import operator | 
					
						
							|  |  |  | from typing import Optional, Sequence | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from langchain_core.callbacks import Callbacks | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain_core.documents import BaseDocumentCompressor, Document | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RerankCompressor(BaseDocumentCompressor): | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function: Any | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |     top_n: int | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     reranking_function: Any | 
					
						
							|  |  |  |     r_score: float | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class Config: | 
					
						
							| 
									
										
										
										
											2024-09-19 23:05:49 +08:00
										 |  |  |         extra = "forbid" | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         arbitrary_types_allowed = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compress_documents( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         documents: Sequence[Document], | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         callbacks: Optional[Callbacks] = None, | 
					
						
							|  |  |  |     ) -> Sequence[Document]: | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |         reranking = self.reranking_function is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if reranking: | 
					
						
							| 
									
										
										
										
											2025-07-14 17:59:10 +08:00
										 |  |  |             scores = self.reranking_function( | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |                 [(query, doc.page_content) for doc in documents] | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-07-01 08:13:56 +08:00
										 |  |  |             from sentence_transformers import util | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-16 09:05:04 +08:00
										 |  |  |             query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             document_embedding = self.embedding_function( | 
					
						
							| 
									
										
										
										
											2025-03-31 12:55:15 +08:00
										 |  |  |                 [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             scores = util.cos_sim(query_embedding, document_embedding)[0] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-10 23:33:34 +08:00
										 |  |  |         docs_with_scores = list( | 
					
						
							|  |  |  |             zip(documents, scores.tolist() if not isinstance(scores, list) else scores) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         if self.r_score: | 
					
						
							|  |  |  |             docs_with_scores = [ | 
					
						
							|  |  |  |                 (d, s) for d, s in docs_with_scores if s >= self.r_score | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         final_results = [] | 
					
						
							|  |  |  |         for doc, doc_score in result[: self.top_n]: | 
					
						
							|  |  |  |             metadata = doc.metadata | 
					
						
							|  |  |  |             metadata["score"] = doc_score | 
					
						
							|  |  |  |             doc = Document( | 
					
						
							|  |  |  |                 page_content=doc.page_content, | 
					
						
							|  |  |  |                 metadata=metadata, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             final_results.append(doc) | 
					
						
							|  |  |  |         return final_results |