317 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			317 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
import logging
import ssl
from typing import Optional

from elasticsearch import Elasticsearch, BadRequestError
from elasticsearch.helpers import bulk, scan

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
class ElasticsearchClient:
    """
    Vector-store client backed by Elasticsearch.

    Important:
    In order to reduce the number of indices, and since the embedding vector
    length of an index is fixed, we avoid creating an index for each
    collection/file. Instead the collection name is stored as a field on every
    document, and documents are separated into different indices based only on
    the embedding length (see ``_get_index_name``).
    """

    # Logger used to report Elasticsearch errors that are handled best-effort.
    _log = logging.getLogger(__name__)

    def __init__(self):
        """Build the Elasticsearch client from the open_webui configuration."""
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX

        # Only pass basic-auth credentials when both parts are configured.
        basic_auth = (
            (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
            if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
            else None
        )
        # NOTE(review): the Elasticsearch client is understood to reject
        # `hosts` and `cloud_id` given together, so the URL is only passed
        # when no cloud id is configured -- confirm against the client docs.
        hosts = None if ELASTICSEARCH_CLOUD_ID else [ELASTICSEARCH_URL]

        self.client = Elasticsearch(
            hosts=hosts,
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=basic_auth,
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        """Return the name of the index holding vectors of length *dimension*."""
        return f"{self.index_prefix}_d{dimension}"

    @staticmethod
    def _hits_to_get_result(hits) -> Optional[GetResult]:
        """Convert an iterable of Elasticsearch hits into a ``GetResult``.

        Returns ``None`` when *hits* is empty. Each hit must carry ``_id`` and
        ``_source``; missing ``text``/``metadata`` fields become ``None``.
        """
        if not hits:
            return None

        ids = [hit["_id"] for hit in hits]
        documents = [hit["_source"].get("text") for hit in hits]
        metadatas = [hit["_source"].get("metadata") for hit in hits]

        # The result types wrap every column in an outer single-element list.
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    @staticmethod
    def _metadata_term_filters(criteria: Optional[dict]) -> list:
        """Build one ``term`` clause per entry of *criteria*.

        Keys are looked up under the ``metadata`` object of the stored
        document (``metadata.<key>``). ``None`` or an empty dict yields an
        empty list.
        """
        return [
            {"term": {f"metadata.{field}": value}}
            for field, value in (criteria or {}).items()
        ]

    # Status: works
    def _scan_result_to_get_result(self, result) -> Optional[GetResult]:
        """Convert a list of hits (as produced by ``scan``) into a ``GetResult``.

        Returns ``None`` when *result* is empty.
        """
        return self._hits_to_get_result(result)

    # Status: works
    def _result_to_get_result(self, result) -> Optional[GetResult]:
        """Convert a search response into a ``GetResult``.

        Returns ``None`` when the response contains no hits.
        """
        return self._hits_to_get_result(result["hits"]["hits"])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        """Convert a search response into a ``SearchResult``.

        The hit ``_score`` is reported as the distance value. An empty
        response yields a ``SearchResult`` with empty inner lists.
        """
        hits = result["hits"]["hits"]

        ids = [hit["_id"] for hit in hits]
        distances = [hit["_score"] for hit in hits]
        documents = [hit["_source"].get("text") for hit in hits]
        metadatas = [hit["_source"].get("metadata") for hit in hits]

        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        """Create the index for vectors of length *dimension*."""
        body = {
            "mappings": {
                # Map every dynamically discovered string field to `keyword`.
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # Fixed per index.
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        """Yield consecutive slices of *items* of at most *batch_size* elements."""
        for start in range(0, len(items), batch_size):
            # Slicing already clamps the upper bound to len(items).
            yield items[start : start + batch_size]

    # Status: works
    def has_collection(self, collection_name) -> bool:
        """Return True when at least one document belongs to *collection_name*.

        Searches every index matching the configured prefix. Any error raised
        by the count request is logged and reported as ``False``.
        """
        query_body = {
            "query": {
                "bool": {"filter": [{"term": {"collection": collection_name}}]}
            }
        }

        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
            return result.body["count"] > 0
        except Exception:
            # Best-effort: a failed lookup is treated as "collection absent".
            self._log.exception(
                "Failed to check for collection %r in Elasticsearch", collection_name
            )
            return False

    def delete_collection(self, collection_name: str):
        """Delete every document of *collection_name* from all prefixed indices."""
        query = {"query": {"term": {"collection": collection_name}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        """Return the *limit* documents of the collection closest to ``vectors[0]``.

        Only the first query vector is used. The score is
        ``cosineSimilarity + 1.0``. The index is chosen from the length of the
        query vector. Returns ``None`` when *vectors* is empty.
        """
        if not vectors:
            return None

        query_vector = vectors[0]  # Assuming single query vector
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {
                            "filter": [{"term": {"collection": collection_name}}]
                        }
                    },
                    "script": {
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {"vector": query_vector},
                    },
                }
            },
        }

        result = self.client.search(
            index=self._get_index_name(len(query_vector)), body=query
        )

        return self._result_to_search_result(result)

    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Return documents of the collection whose metadata matches *filter*.

        Each ``filter`` entry becomes a ``term`` clause on ``metadata.<key>``,
        the same convention ``delete`` uses. At most *limit* documents are
        returned (10 when *limit* is falsy). Returns ``None`` when the
        collection does not exist, nothing matches, or the search fails (the
        failure is logged).
        """
        if not self.has_collection(collection_name):
            return None

        clauses = self._metadata_term_filters(filter)
        clauses.append({"term": {"collection": collection_name}})

        query_body = {
            "query": {"bool": {"filter": clauses}},
            "_source": ["text", "metadata"],
            # Kept in the body so `body` is not mixed with per-field kwargs.
            "size": limit if limit else 10,
        }

        try:
            result = self.client.search(index=f"{self.index_prefix}*", body=query_body)
            return self._result_to_get_result(result)
        except Exception:
            self._log.exception(
                "Failed to query collection %r in Elasticsearch", collection_name
            )
            return None

    # Status: works
    def _has_index(self, dimension: int):
        """Return whether the index for vectors of length *dimension* exists."""
        return self.client.indices.exists(
            index=self._get_index_name(dimension=dimension)
        )

    def get_or_create_index(self, dimension: int):
        """Create the index for *dimension* unless it already exists."""
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    def _ensure_indices(self, items: list[VectorItem]) -> dict:
        """Ensure an index exists for every vector length found in *items*.

        Returns a mapping from vector length to index name so callers can
        route each item to the index matching its own vector.
        """
        index_by_dimension = {}
        for item in items:
            dimension = len(item["vector"])
            if dimension not in index_by_dimension:
                self.get_or_create_index(dimension=dimension)
                index_by_dimension[dimension] = self._get_index_name(
                    dimension=dimension
                )
        return index_by_dimension

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every document of *collection_name*, or ``None`` if it is empty."""
        # Get all the items in the collection.
        query = {
            "query": {
                "bool": {"filter": [{"term": {"collection": collection_name}}]}
            },
            "_source": ["text", "metadata"],
        }
        # Materialised because the converter needs to test for emptiness.
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))

        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        """Index *items* into *collection_name* in batches via the bulk helper.

        Each item is subscripted for ``id``, ``vector``, ``text`` and
        ``metadata``. Missing indices are created first. An empty *items*
        list is a no-op.
        """
        if not items:
            return

        index_by_dimension = self._ensure_indices(items)

        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": index_by_dimension[len(item["vector"])],
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Upsert documents using the update API with doc_as_upsert=True.
    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update *items* in *collection_name* in batches.

        Uses bulk ``update`` actions with ``doc_as_upsert`` set. Missing
        indices are created first. An empty *items* list is a no-op.
        """
        if not items:
            return

        index_by_dimension = self._ensure_indices(items)

        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": index_by_dimension[len(item["vector"])],
                    "_id": item["id"],
                    "doc": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Delete specific documents from a collection by filtering on both collection and document IDs.
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete documents of *collection_name* selected by *ids* or *filter*.

        *ids* takes precedence over *filter*; *filter* entries are matched
        against ``metadata.<key>``. When neither is given, every document of
        the collection is deleted.
        """
        clauses = [{"term": {"collection": collection_name}}]

        # logic based on chromaDB
        if ids:
            clauses.append({"terms": {"_id": ids}})
        elif filter:
            clauses.extend(self._metadata_term_filters(filter))

        query = {"query": {"bool": {"filter": clauses}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    def reset(self):
        """Delete every index whose name starts with the configured prefix."""
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
 |