| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-09-12 21:19:40 +08:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from typing import Optional, Union | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | import requests | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | from huggingface_hub import snapshot_download | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | from langchain_community.retrievers import BM25Retriever | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain_core.documents import Document | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.apps.ollama.main import ( | 
					
						
							|  |  |  |     GenerateEmbeddingsForm, | 
					
						
							|  |  |  |     generate_ollama_embeddings, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-09-28 07:28:45 +08:00
										 |  |  | from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT | 
					
						
							| 
									
										
										
										
											2024-09-04 22:54:48 +08:00
										 |  |  | from open_webui.utils.misc import get_last_user_message | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 09:27:50 +08:00
										 |  |  | from open_webui.env import SRC_LOG_LEVELS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  | log = logging.getLogger(__name__) | 
					
						
							|  |  |  | log.setLevel(SRC_LOG_LEVELS["RAG"]) | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  | from typing import Any | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from langchain_core.callbacks import CallbackManagerForRetrieverRun | 
					
						
							|  |  |  | from langchain_core.retrievers import BaseRetriever | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class VectorSearchRetriever(BaseRetriever): | 
					
						
							|  |  |  |     collection_name: Any | 
					
						
							|  |  |  |     embedding_function: Any | 
					
						
							|  |  |  |     top_k: int | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_relevant_documents( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         *, | 
					
						
							|  |  |  |         run_manager: CallbackManagerForRetrieverRun, | 
					
						
							|  |  |  |     ) -> list[Document]: | 
					
						
							|  |  |  |         result = VECTOR_DB_CLIENT.search( | 
					
						
							|  |  |  |             collection_name=self.collection_name, | 
					
						
							|  |  |  |             vectors=[self.embedding_function(query)], | 
					
						
							|  |  |  |             limit=self.top_k, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |         ids = result.ids[0] | 
					
						
							|  |  |  |         metadatas = result.metadatas[0] | 
					
						
							|  |  |  |         documents = result.documents[0] | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         results = [] | 
					
						
							|  |  |  |         for idx in range(len(ids)): | 
					
						
							|  |  |  |             results.append( | 
					
						
							|  |  |  |                 Document( | 
					
						
							|  |  |  |                     metadata=metadatas[idx], | 
					
						
							|  |  |  |                     page_content=documents[idx], | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return results | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_doc( | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     collection_name: str, | 
					
						
							| 
									
										
										
										
											2024-09-30 22:18:02 +08:00
										 |  |  |     query_embedding: list[float], | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  |     k: int, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |         result = VECTOR_DB_CLIENT.search( | 
					
						
							|  |  |  |             collection_name=collection_name, | 
					
						
							| 
									
										
										
										
											2024-09-30 22:18:02 +08:00
										 |  |  |             vectors=[query_embedding], | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             limit=k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         log.info(f"query_doc:result {result}") | 
					
						
							|  |  |  |         return result | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |         print(e) | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         raise e | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_doc_with_hybrid_search( | 
					
						
							|  |  |  |     collection_name: str, | 
					
						
							|  |  |  |     query: str, | 
					
						
							|  |  |  |     embedding_function, | 
					
						
							|  |  |  |     k: int, | 
					
						
							|  |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2024-05-02 13:45:19 +08:00
										 |  |  |     r: float, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |         result = VECTOR_DB_CLIENT.get(collection_name=collection_name) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         bm25_retriever = BM25Retriever.from_texts( | 
					
						
							| 
									
										
										
										
											2024-09-13 13:21:47 +08:00
										 |  |  |             texts=result.documents[0], | 
					
						
							|  |  |  |             metadatas=result.metadatas[0], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         bm25_retriever.k = k | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |         vector_search_retriever = VectorSearchRetriever( | 
					
						
							|  |  |  |             collection_name=collection_name, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             embedding_function=embedding_function, | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             top_k=k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ensemble_retriever = EnsembleRetriever( | 
					
						
							| 
									
										
										
										
											2024-09-10 11:37:06 +08:00
										 |  |  |             retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         compressor = RerankCompressor( | 
					
						
							|  |  |  |             embedding_function=embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |             top_n=k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             reranking_function=reranking_function, | 
					
						
							|  |  |  |             r_score=r, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         compression_retriever = ContextualCompressionRetriever( | 
					
						
							|  |  |  |             base_compressor=compressor, base_retriever=ensemble_retriever | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-04-26 05:03:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         result = compression_retriever.invoke(query) | 
					
						
							|  |  |  |         result = { | 
					
						
							|  |  |  |             "distances": [[d.metadata.get("score") for d in result]], | 
					
						
							|  |  |  |             "documents": [[d.page_content for d in result]], | 
					
						
							|  |  |  |             "metadatas": [[d.metadata for d in result]], | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |         log.info(f"query_doc_with_hybrid_search:result {result}") | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |         return result | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-13 12:48:54 +08:00
										 |  |  | def merge_and_sort_query_results( | 
					
						
							|  |  |  |     query_results: list[dict], k: int, reverse: bool = False | 
					
						
							|  |  |  | ) -> list[dict]: | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  |     # Initialize lists to store combined data | 
					
						
							|  |  |  |     combined_distances = [] | 
					
						
							|  |  |  |     combined_documents = [] | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     combined_metadatas = [] | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for data in query_results: | 
					
						
							|  |  |  |         combined_distances.extend(data["distances"][0]) | 
					
						
							|  |  |  |         combined_documents.extend(data["documents"][0]) | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |         combined_metadatas.extend(data["metadatas"][0]) | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     # Create a list of tuples (distance, document, metadata) | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     combined = list(zip(combined_distances, combined_documents, combined_metadatas)) | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Sort the list based on distances | 
					
						
							| 
									
										
										
										
											2024-04-26 09:00:47 +08:00
										 |  |  |     combined.sort(key=lambda x: x[0], reverse=reverse) | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     # We don't have anything :-( | 
					
						
							|  |  |  |     if not combined: | 
					
						
							|  |  |  |         sorted_distances = [] | 
					
						
							|  |  |  |         sorted_documents = [] | 
					
						
							|  |  |  |         sorted_metadatas = [] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         # Unzip the sorted list | 
					
						
							|  |  |  |         sorted_distances, sorted_documents, sorted_metadatas = zip(*combined) | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         # Slicing the lists to include only k elements | 
					
						
							|  |  |  |         sorted_distances = list(sorted_distances)[:k] | 
					
						
							|  |  |  |         sorted_documents = list(sorted_documents)[:k] | 
					
						
							|  |  |  |         sorted_metadatas = list(sorted_metadatas)[:k] | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Create the output dictionary | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     result = { | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  |         "distances": [sorted_distances], | 
					
						
							|  |  |  |         "documents": [sorted_documents], | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |         "metadatas": [sorted_metadatas], | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     return result | 
					
						
							| 
									
										
										
										
											2024-03-09 11:26:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def query_collection( | 
					
						
							| 
									
										
										
										
											2024-08-14 20:46:31 +08:00
										 |  |  |     collection_names: list[str], | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     query: str, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function, | 
					
						
							|  |  |  |     k: int, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-10-04 16:04:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     results = [] | 
					
						
							| 
									
										
										
										
											2024-09-30 22:18:02 +08:00
										 |  |  |     query_embedding = embedding_function(query) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     for collection_name in collection_names: | 
					
						
							| 
									
										
										
										
											2024-08-23 21:02:23 +08:00
										 |  |  |         if collection_name: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 result = query_doc( | 
					
						
							|  |  |  |                     collection_name=collection_name, | 
					
						
							|  |  |  |                     k=k, | 
					
						
							| 
									
										
										
										
											2024-09-30 22:18:02 +08:00
										 |  |  |                     query_embedding=query_embedding, | 
					
						
							| 
									
										
										
										
											2024-08-23 21:02:23 +08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |                 results.append(result.model_dump()) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:34:52 +08:00
										 |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.exception(f"Error when querying the collection: {e}") | 
					
						
							| 
									
										
										
										
											2024-08-23 21:02:23 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             pass | 
					
						
							| 
									
										
										
										
											2024-08-23 21:02:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     return merge_and_sort_query_results(results, k=k) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def query_collection_with_hybrid_search( | 
					
						
							| 
									
										
										
										
											2024-08-14 20:46:31 +08:00
										 |  |  |     collection_names: list[str], | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     query: str, | 
					
						
							|  |  |  |     embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     k: int, | 
					
						
							|  |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     r: float, | 
					
						
							| 
									
										
										
										
											2024-09-12 21:50:18 +08:00
										 |  |  | ) -> dict: | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     results = [] | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |     error = False | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     for collection_name in collection_names: | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             result = query_doc_with_hybrid_search( | 
					
						
							| 
									
										
										
										
											2024-04-23 02:27:43 +08:00
										 |  |  |                 collection_name=collection_name, | 
					
						
							|  |  |  |                 query=query, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |                 embedding_function=embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-23 02:27:43 +08:00
										 |  |  |                 k=k, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |                 reranking_function=reranking_function, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |                 r=r, | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             results.append(result) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:34:52 +08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception( | 
					
						
							| 
									
										
										
										
											2024-09-13 12:48:54 +08:00
										 |  |  |                 "Error when querying the collection with " f"hybrid_search: {e}" | 
					
						
							| 
									
										
										
										
											2024-09-12 21:34:52 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  |             error = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if error: | 
					
						
							| 
									
										
										
										
											2024-09-13 12:48:54 +08:00
										 |  |  |         raise Exception( | 
					
						
							| 
									
										
										
										
											2024-09-16 17:46:39 +08:00
										 |  |  |             "Hybrid search failed for all collections. Using Non hybrid search as fallback." | 
					
						
							| 
									
										
										
										
											2024-09-13 12:48:54 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-13 13:18:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     return merge_and_sort_query_results(results, k=k, reverse=True) | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-09 14:34:47 +08:00
										 |  |  | def rag_template(template: str, context: str, query: str): | 
					
						
							| 
									
										
										
										
											2024-09-12 21:19:24 +08:00
										 |  |  |     count = template.count("[context]") | 
					
						
							|  |  |  |     assert "[context]" in template, "RAG template does not contain '[context]'" | 
					
						
							| 
									
										
										
										
											2024-09-13 13:06:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-12 22:04:41 +08:00
										 |  |  |     if "<context>" in context and "</context>" in context: | 
					
						
							|  |  |  |         log.debug( | 
					
						
							|  |  |  |             "WARNING: Potential prompt injection attack: the RAG " | 
					
						
							|  |  |  |             "context contains '<context>' and '</context>'. This might be " | 
					
						
							|  |  |  |             "nothing, or the user might be trying to hack something." | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:19:40 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if "[query]" in context: | 
					
						
							| 
									
										
										
										
											2024-09-13 13:08:02 +08:00
										 |  |  |         query_placeholder = f"[query-{str(uuid.uuid4())}]" | 
					
						
							|  |  |  |         template = template.replace("[query]", query_placeholder) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:19:40 +08:00
										 |  |  |         template = template.replace("[context]", context) | 
					
						
							|  |  |  |         template = template.replace(query_placeholder, query) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         template = template.replace("[context]", context) | 
					
						
							|  |  |  |         template = template.replace("[query]", query) | 
					
						
							| 
									
										
										
										
											2024-03-09 14:34:47 +08:00
										 |  |  |     return template | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  | def get_embedding_function( | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |     embedding_engine, | 
					
						
							|  |  |  |     embedding_model, | 
					
						
							|  |  |  |     embedding_function, | 
					
						
							|  |  |  |     openai_key, | 
					
						
							|  |  |  |     openai_url, | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |     batch_size, | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | ): | 
					
						
							|  |  |  |     if embedding_engine == "": | 
					
						
							|  |  |  |         return lambda query: embedding_function.encode(query).tolist() | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     elif embedding_engine in ["ollama", "openai"]: | 
					
						
							|  |  |  |         if embedding_engine == "ollama": | 
					
						
							|  |  |  |             func = lambda query: generate_ollama_embeddings( | 
					
						
							|  |  |  |                 GenerateEmbeddingsForm( | 
					
						
							|  |  |  |                     **{ | 
					
						
							|  |  |  |                         "model": embedding_model, | 
					
						
							|  |  |  |                         "prompt": query, | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         elif embedding_engine == "openai": | 
					
						
							|  |  |  |             func = lambda query: generate_openai_embeddings( | 
					
						
							|  |  |  |                 model=embedding_model, | 
					
						
							|  |  |  |                 text=query, | 
					
						
							|  |  |  |                 key=openai_key, | 
					
						
							|  |  |  |                 url=openai_url, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def generate_multiple(query, f): | 
					
						
							|  |  |  |             if isinstance(query, list): | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |                 if embedding_engine == "openai": | 
					
						
							|  |  |  |                     embeddings = [] | 
					
						
							|  |  |  |                     for i in range(0, len(query), batch_size): | 
					
						
							|  |  |  |                         embeddings.extend(f(query[i : i + batch_size])) | 
					
						
							|  |  |  |                     return embeddings | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     return [f(q) for q in query] | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 return f(query) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return lambda query: generate_multiple(query, func) | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-11 16:10:24 +08:00
										 |  |  | def get_rag_context( | 
					
						
							| 
									
										
										
										
											2024-06-19 05:55:18 +08:00
										 |  |  |     files, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  |     messages, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  |     k, | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     reranking_function, | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     r, | 
					
						
							| 
									
										
										
										
											2024-04-27 02:41:39 +08:00
										 |  |  |     hybrid_search, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-06-19 05:55:18 +08:00
										 |  |  |     log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}") | 
					
						
							| 
									
										
										
										
											2024-06-09 18:01:25 +08:00
										 |  |  |     query = get_last_user_message(messages) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     extracted_collections = [] | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  |     relevant_contexts = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-19 05:55:18 +08:00
										 |  |  |     for file in files: | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |         if file.get("context") == "full": | 
					
						
							|  |  |  |             context = { | 
					
						
							| 
									
										
										
										
											2024-10-04 13:22:22 +08:00
										 |  |  |                 "documents": [[file.get("file").get("data", {}).get("content")]], | 
					
						
							| 
									
										
										
										
											2024-09-30 04:55:53 +08:00
										 |  |  |                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             context = None | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-04 14:06:47 +08:00
										 |  |  |             collection_names = [] | 
					
						
							|  |  |  |             if file.get("type") == "collection": | 
					
						
							|  |  |  |                 if file.get("legacy"): | 
					
						
							|  |  |  |                     collection_names = file.get("collection_names", []) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     collection_names.append(file["id"]) | 
					
						
							|  |  |  |             elif file.get("collection_name"): | 
					
						
							|  |  |  |                 collection_names.append(file["collection_name"]) | 
					
						
							|  |  |  |             elif file.get("id"): | 
					
						
							| 
									
										
										
										
											2024-10-04 15:59:19 +08:00
										 |  |  |                 if file.get("legacy"): | 
					
						
							|  |  |  |                     collection_names.append(f"{file['id']}") | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     collection_names.append(f"file-{file['id']}") | 
					
						
							| 
									
										
										
										
											2024-05-07 06:49:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             collection_names = set(collection_names).difference(extracted_collections) | 
					
						
							|  |  |  |             if not collection_names: | 
					
						
							|  |  |  |                 log.debug(f"skipping {file} as it has already been extracted") | 
					
						
							|  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 context = None | 
					
						
							| 
									
										
										
										
											2024-10-05 10:32:33 +08:00
										 |  |  |                 if file.get("type") == "text": | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                     context = file["content"] | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     if hybrid_search: | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             context = query_collection_with_hybrid_search( | 
					
						
							|  |  |  |                                 collection_names=collection_names, | 
					
						
							|  |  |  |                                 query=query, | 
					
						
							|  |  |  |                                 embedding_function=embedding_function, | 
					
						
							|  |  |  |                                 k=k, | 
					
						
							|  |  |  |                                 reranking_function=reranking_function, | 
					
						
							|  |  |  |                                 r=r, | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  |                         except Exception as e: | 
					
						
							|  |  |  |                             log.debug( | 
					
						
							|  |  |  |                                 "Error when using hybrid search, using" | 
					
						
							|  |  |  |                                 " non hybrid search as fallback." | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if (not hybrid_search) or (context is None): | 
					
						
							|  |  |  |                         context = query_collection( | 
					
						
							| 
									
										
										
										
											2024-09-12 21:58:26 +08:00
										 |  |  |                             collection_names=collection_names, | 
					
						
							|  |  |  |                             query=query, | 
					
						
							|  |  |  |                             embedding_function=embedding_function, | 
					
						
							|  |  |  |                             k=k, | 
					
						
							| 
									
										
										
										
											2024-09-13 12:48:54 +08:00
										 |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-09-12 21:58:26 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             extracted_collections.extend(collection_names) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         if context: | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |             relevant_contexts.append({**context, "file": file}) | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-02 10:33:58 +08:00
										 |  |  |     contexts = [] | 
					
						
							| 
									
										
										
										
											2024-05-06 21:14:51 +08:00
										 |  |  |     citations = [] | 
					
						
							| 
									
										
										
										
											2024-03-11 09:40:50 +08:00
										 |  |  |     for context in relevant_contexts: | 
					
						
							| 
									
										
										
										
											2024-04-30 13:51:30 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             if "documents" in context: | 
					
						
							| 
									
										
										
										
											2024-07-02 10:33:58 +08:00
										 |  |  |                 contexts.append( | 
					
						
							|  |  |  |                     "\n\n".join( | 
					
						
							|  |  |  |                         [text for text in context["documents"][0] if text is not None] | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-05-07 06:49:00 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-06 21:14:51 +08:00
										 |  |  |                 if "metadatas" in context: | 
					
						
							|  |  |  |                     citations.append( | 
					
						
							|  |  |  |                         { | 
					
						
							| 
									
										
										
										
											2024-09-30 04:52:27 +08:00
										 |  |  |                             "source": context["file"], | 
					
						
							| 
									
										
										
										
											2024-05-06 21:14:51 +08:00
										 |  |  |                             "document": context["documents"][0], | 
					
						
							|  |  |  |                             "metadata": context["metadatas"][0], | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-04-30 13:51:30 +08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-05-07 06:14:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-02 10:33:58 +08:00
										 |  |  |     return contexts, citations | 
					
						
							| 
									
										
										
										
											2024-04-05 01:01:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-05 02:07:42 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | def get_model_path(model: str, update_model: bool = False): | 
					
						
							|  |  |  |     # Construct huggingface_hub kwargs with local_files_only to return the snapshot path | 
					
						
							|  |  |  |     cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     local_files_only = not update_model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     snapshot_kwargs = { | 
					
						
							|  |  |  |         "cache_dir": cache_dir, | 
					
						
							|  |  |  |         "local_files_only": local_files_only, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-26 02:28:31 +08:00
										 |  |  |     log.debug(f"model: {model}") | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |     log.debug(f"snapshot_kwargs: {snapshot_kwargs}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Inspiration from upstream sentence_transformers | 
					
						
							|  |  |  |     if ( | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         os.path.exists(model) | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |         or ("\\" in model or model.count("/") > 1) | 
					
						
							|  |  |  |         and local_files_only | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         # If fully qualified path exists, return input, else set repo_id | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         return model | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  |     elif "/" not in model: | 
					
						
							|  |  |  |         # Set valid repo_id for model short-name | 
					
						
							|  |  |  |         model = "sentence-transformers" + "/" + model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     snapshot_kwargs["repo_id"] = model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Attempt to query the huggingface_hub library to determine the local path and/or to update | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         model_repo_path = snapshot_download(**snapshot_kwargs) | 
					
						
							|  |  |  |         log.debug(f"model_repo_path: {model_repo_path}") | 
					
						
							|  |  |  |         return model_repo_path | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.exception(f"Cannot determine model snapshot path: {e}") | 
					
						
							| 
									
										
										
										
											2024-07-15 17:09:05 +08:00
										 |  |  |         return model | 
					
						
							| 
									
										
										
										
											2024-04-25 20:49:59 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  | def generate_openai_embeddings( | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |     model: str, | 
					
						
							|  |  |  |     text: Union[str, list[str]], | 
					
						
							|  |  |  |     key: str, | 
					
						
							|  |  |  |     url: str = "https://api.openai.com/v1", | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |     if isinstance(text, list): | 
					
						
							|  |  |  |         embeddings = generate_openai_batch_embeddings(model, text, key, url) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         embeddings = generate_openai_batch_embeddings(model, [text], key, url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return embeddings[0] if isinstance(text, str) else embeddings | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_openai_batch_embeddings( | 
					
						
							|  |  |  |     model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" | 
					
						
							|  |  |  | ) -> Optional[list[list[float]]]: | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |     try: | 
					
						
							|  |  |  |         r = requests.post( | 
					
						
							| 
									
										
										
										
											2024-04-21 04:15:59 +08:00
										 |  |  |             f"{url}/embeddings", | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |             headers={ | 
					
						
							|  |  |  |                 "Content-Type": "application/json", | 
					
						
							|  |  |  |                 "Authorization": f"Bearer {key}", | 
					
						
							|  |  |  |             }, | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |             json={"input": texts, "model": model}, | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  |         data = r.json() | 
					
						
							|  |  |  |         if "data" in data: | 
					
						
							| 
									
										
										
										
											2024-06-02 22:34:31 +08:00
										 |  |  |             return [elem["embedding"] for elem in data["data"]] | 
					
						
							| 
									
										
										
										
											2024-04-15 07:15:39 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise "Something went wrong :/" | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         print(e) | 
					
						
							|  |  |  |         return None | 
					
						
							| 
									
										
										
										
											2024-04-23 04:49:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | import operator | 
					
						
							|  |  |  | from typing import Optional, Sequence | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from langchain_core.callbacks import Callbacks | 
					
						
							| 
									
										
										
										
											2024-08-28 06:10:27 +08:00
										 |  |  | from langchain_core.documents import BaseDocumentCompressor, Document | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RerankCompressor(BaseDocumentCompressor): | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |     embedding_function: Any | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |     top_n: int | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |     reranking_function: Any | 
					
						
							|  |  |  |     r_score: float | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     class Config: | 
					
						
							| 
									
										
										
										
											2024-09-19 23:05:49 +08:00
										 |  |  |         extra = "forbid" | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         arbitrary_types_allowed = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def compress_documents( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         documents: Sequence[Document], | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         callbacks: Optional[Callbacks] = None, | 
					
						
							|  |  |  |     ) -> Sequence[Document]: | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |         reranking = self.reranking_function is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if reranking: | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |             scores = self.reranking_function.predict( | 
					
						
							|  |  |  |                 [(query, doc.page_content) for doc in documents] | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-07-01 08:13:56 +08:00
										 |  |  |             from sentence_transformers import util | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-28 03:38:50 +08:00
										 |  |  |             query_embedding = self.embedding_function(query) | 
					
						
							|  |  |  |             document_embedding = self.embedding_function( | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |                 [doc.page_content for doc in documents] | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             scores = util.cos_sim(query_embedding, document_embedding)[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         docs_with_scores = list(zip(documents, scores.tolist())) | 
					
						
							|  |  |  |         if self.r_score: | 
					
						
							|  |  |  |             docs_with_scores = [ | 
					
						
							|  |  |  |                 (d, s) for d, s in docs_with_scores if s >= self.r_score | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-30 01:15:58 +08:00
										 |  |  |         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) | 
					
						
							| 
									
										
										
										
											2024-04-23 07:36:46 +08:00
										 |  |  |         final_results = [] | 
					
						
							|  |  |  |         for doc, doc_score in result[: self.top_n]: | 
					
						
							|  |  |  |             metadata = doc.metadata | 
					
						
							|  |  |  |             metadata["score"] = doc_score | 
					
						
							|  |  |  |             doc = Document( | 
					
						
							|  |  |  |                 page_content=doc.page_content, | 
					
						
							|  |  |  |                 metadata=metadata, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             final_results.append(doc) | 
					
						
							|  |  |  |         return final_results |