Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option
feat: Add Azure OpenAI embedding support
This commit is contained in:
		
						commit
						ff353578db
					
				|  | @ -2184,6 +2184,27 @@ RAG_OPENAI_API_KEY = PersistentConfig( | ||||||
|     os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), |     os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | RAG_AZURE_OPENAI_BASE_URL = PersistentConfig( | ||||||
|  |     "RAG_AZURE_OPENAI_BASE_URL", | ||||||
|  |     "rag.azure_openai.base_url", | ||||||
|  |     os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""), | ||||||
|  | ) | ||||||
|  | RAG_AZURE_OPENAI_API_KEY = PersistentConfig( | ||||||
|  |     "RAG_AZURE_OPENAI_API_KEY", | ||||||
|  |     "rag.azure_openai.api_key", | ||||||
|  |     os.getenv("RAG_AZURE_OPENAI_API_KEY", ""), | ||||||
|  | ) | ||||||
|  | RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig( | ||||||
|  |     "RAG_AZURE_OPENAI_DEPLOYMENT", | ||||||
|  |     "rag.azure_openai.deployment", | ||||||
|  |     os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""), | ||||||
|  | ) | ||||||
|  | RAG_AZURE_OPENAI_VERSION = PersistentConfig( | ||||||
|  |     "RAG_AZURE_OPENAI_VERSION", | ||||||
|  |     "rag.azure_openai.version", | ||||||
|  |     os.getenv("RAG_AZURE_OPENAI_VERSION", ""), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| RAG_OLLAMA_BASE_URL = PersistentConfig( | RAG_OLLAMA_BASE_URL = PersistentConfig( | ||||||
|     "RAG_OLLAMA_BASE_URL", |     "RAG_OLLAMA_BASE_URL", | ||||||
|     "rag.ollama.url", |     "rag.ollama.url", | ||||||
|  |  | ||||||
|  | @ -207,6 +207,10 @@ from open_webui.config import ( | ||||||
|     RAG_FILE_MAX_SIZE, |     RAG_FILE_MAX_SIZE, | ||||||
|     RAG_OPENAI_API_BASE_URL, |     RAG_OPENAI_API_BASE_URL, | ||||||
|     RAG_OPENAI_API_KEY, |     RAG_OPENAI_API_KEY, | ||||||
|  |     RAG_AZURE_OPENAI_BASE_URL, | ||||||
|  |     RAG_AZURE_OPENAI_API_KEY, | ||||||
|  |     RAG_AZURE_OPENAI_DEPLOYMENT, | ||||||
|  |     RAG_AZURE_OPENAI_VERSION, | ||||||
|     RAG_OLLAMA_BASE_URL, |     RAG_OLLAMA_BASE_URL, | ||||||
|     RAG_OLLAMA_API_KEY, |     RAG_OLLAMA_API_KEY, | ||||||
|     CHUNK_OVERLAP, |     CHUNK_OVERLAP, | ||||||
|  | @ -717,6 +721,11 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE | ||||||
| app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL | app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL | ||||||
| app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY | app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY | ||||||
| 
 | 
 | ||||||
|  | app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL | ||||||
|  | app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY | ||||||
|  | app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT | ||||||
|  | app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION | ||||||
|  | 
 | ||||||
| app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL | app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL | ||||||
| app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY | app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY | ||||||
| 
 | 
 | ||||||
|  | @ -811,14 +820,32 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( | ||||||
|     ( |     ( | ||||||
|         app.state.config.RAG_OPENAI_API_BASE_URL |         app.state.config.RAG_OPENAI_API_BASE_URL | ||||||
|         if app.state.config.RAG_EMBEDDING_ENGINE == "openai" |         if app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|         else app.state.config.RAG_OLLAMA_BASE_URL |         else ( | ||||||
|  |             app.state.config.RAG_OLLAMA_BASE_URL | ||||||
|  |             if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |             else app.state.config.RAG_AZURE_OPENAI_BASE_URL | ||||||
|  |         ) | ||||||
|     ), |     ), | ||||||
|     ( |     ( | ||||||
|         app.state.config.RAG_OPENAI_API_KEY |         app.state.config.RAG_OPENAI_API_KEY | ||||||
|         if app.state.config.RAG_EMBEDDING_ENGINE == "openai" |         if app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|         else app.state.config.RAG_OLLAMA_API_KEY |         else ( | ||||||
|  |             app.state.config.RAG_OLLAMA_API_KEY | ||||||
|  |             if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |             else app.state.config.RAG_AZURE_OPENAI_API_KEY | ||||||
|  |         ) | ||||||
|     ), |     ), | ||||||
|     app.state.config.RAG_EMBEDDING_BATCH_SIZE, |     app.state.config.RAG_EMBEDDING_BATCH_SIZE, | ||||||
|  |     ( | ||||||
|  |         app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT | ||||||
|  |         if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |         else None | ||||||
|  |     ), | ||||||
|  |     ( | ||||||
|  |         app.state.config.RAG_AZURE_OPENAI_VERSION | ||||||
|  |         if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |         else None | ||||||
|  |     ), | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| ######################################## | ######################################## | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ from typing import Optional, Union | ||||||
| import requests | import requests | ||||||
| import hashlib | import hashlib | ||||||
| from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||||
|  | import time | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import snapshot_download | from huggingface_hub import snapshot_download | ||||||
| from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | ||||||
|  | @ -400,12 +401,14 @@ def get_embedding_function( | ||||||
|     url, |     url, | ||||||
|     key, |     key, | ||||||
|     embedding_batch_size, |     embedding_batch_size, | ||||||
|  |     deployment=None, | ||||||
|  |     version=None, | ||||||
| ): | ): | ||||||
|     if embedding_engine == "": |     if embedding_engine == "": | ||||||
|         return lambda query, prefix=None, user=None: embedding_function.encode( |         return lambda query, prefix=None, user=None: embedding_function.encode( | ||||||
|             query, **({"prompt": prefix} if prefix else {}) |             query, **({"prompt": prefix} if prefix else {}) | ||||||
|         ).tolist() |         ).tolist() | ||||||
|     elif embedding_engine in ["ollama", "openai"]: |     elif embedding_engine in ["ollama", "openai", "azure_openai"]: | ||||||
|         func = lambda query, prefix=None, user=None: generate_embeddings( |         func = lambda query, prefix=None, user=None: generate_embeddings( | ||||||
|             engine=embedding_engine, |             engine=embedding_engine, | ||||||
|             model=embedding_model, |             model=embedding_model, | ||||||
|  | @ -414,6 +417,8 @@ def get_embedding_function( | ||||||
|             url=url, |             url=url, | ||||||
|             key=key, |             key=key, | ||||||
|             user=user, |             user=user, | ||||||
|  |             deployment=deployment, | ||||||
|  |             version=version, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         def generate_multiple(query, prefix, user, func): |         def generate_multiple(query, prefix, user, func): | ||||||
|  | @ -697,6 +702,61 @@ def generate_openai_batch_embeddings( | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
def generate_azure_openai_batch_embeddings(
    deployment: str,
    texts: list[str],
    url: str,
    key: str = "",
    model: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Request embeddings for a batch of texts from an Azure OpenAI deployment.

    POSTs to ``{url}/openai/deployments/{deployment}/embeddings`` with the
    ``api-version`` query parameter set to ``version`` and the key sent in the
    ``api-key`` header. HTTP 429 responses are retried (up to 5 attempts in
    total), waiting for the number of seconds given by ``Retry-After``.

    Args:
        deployment: Name of the Azure OpenAI deployment to call.
        texts: Input strings to embed, sent together as one request.
        url: Base URL of the Azure OpenAI resource; a trailing ``/`` is ignored.
        key: API key placed in the ``api-key`` request header.
        model: Value sent as ``model`` in the JSON payload.
        version: Value of the ``api-version`` query parameter.
        prefix: Optional prefix, added to the payload under
            ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` when both are strings.
        user: Optional user whose name/id/email/role are forwarded as
            ``X-OpenWebUI-User-*`` headers when
            ``ENABLE_FORWARD_USER_INFO_HEADERS`` is truthy.

    Returns:
        One embedding per entry of the response's ``data`` list, or ``None``
        when the request fails, the response has no ``data`` key, or every
        attempt was answered with HTTP 429.
    """
    max_attempts = 5
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        # Strip a trailing slash so a base URL such as "https://x.azure.com/"
        # does not produce ".../​/openai/..." with an empty path segment.
        url = f"{url.rstrip('/')}/openai/deployments/{deployment}/embeddings?api-version={version}"

        # The headers are identical for every attempt, so build them once.
        headers = {
            "Content-Type": "application/json",
            "api-key": key,
            **(
                {
                    "X-OpenWebUI-User-Name": user.name,
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
                if ENABLE_FORWARD_USER_INFO_HEADERS and user
                else {}
            ),
        }

        for attempt in range(1, max_attempts + 1):
            r = requests.post(url, headers=headers, json=json_data)
            if r.status_code == 429:
                if attempt == max_attempts:
                    # No further attempt follows, so waiting would only delay
                    # the failure report.
                    break
                # Retry-After may be an HTTP-date or otherwise non-numeric;
                # fall back to one second instead of aborting the retry loop.
                try:
                    retry = float(r.headers.get("Retry-After", "1"))
                except (TypeError, ValueError):
                    retry = 1.0
                # Rejects negative, NaN and infinite values in one comparison.
                if not (0 <= retry < float("inf")):
                    retry = 1.0
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            else:
                raise Exception("Something went wrong :/")

        log.error(
            f"Azure OpenAI embeddings still rate limited (HTTP 429) after {max_attempts} attempts"
        )
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None
|  | 
 | ||||||
|  | 
 | ||||||
| def generate_ollama_batch_embeddings( | def generate_ollama_batch_embeddings( | ||||||
|     model: str, |     model: str, | ||||||
|     texts: list[str], |     texts: list[str], | ||||||
|  | @ -794,6 +854,32 @@ def generate_embeddings( | ||||||
|                 model, [text], url, key, prefix, user |                 model, [text], url, key, prefix, user | ||||||
|             ) |             ) | ||||||
|         return embeddings[0] if isinstance(text, str) else embeddings |         return embeddings[0] if isinstance(text, str) else embeddings | ||||||
|  |     elif engine == "azure_openai": | ||||||
|  |         deployment = kwargs.get("deployment", "") | ||||||
|  |         version = kwargs.get("version", "") | ||||||
|  |         if isinstance(text, list): | ||||||
|  |             embeddings = generate_azure_openai_batch_embeddings( | ||||||
|  |                 deployment, | ||||||
|  |                 text, | ||||||
|  |                 url, | ||||||
|  |                 key, | ||||||
|  |                 model, | ||||||
|  |                 version, | ||||||
|  |                 prefix, | ||||||
|  |                 user, | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             embeddings = generate_azure_openai_batch_embeddings( | ||||||
|  |                 deployment, | ||||||
|  |                 [text], | ||||||
|  |                 url, | ||||||
|  |                 key, | ||||||
|  |                 model, | ||||||
|  |                 version, | ||||||
|  |                 prefix, | ||||||
|  |                 user, | ||||||
|  |             ) | ||||||
|  |         return embeddings[0] if isinstance(text, str) else embeddings | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| import operator | import operator | ||||||
|  |  | ||||||
|  | @ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)): | ||||||
|             "url": request.app.state.config.RAG_OLLAMA_BASE_URL, |             "url": request.app.state.config.RAG_OLLAMA_BASE_URL, | ||||||
|             "key": request.app.state.config.RAG_OLLAMA_API_KEY, |             "key": request.app.state.config.RAG_OLLAMA_API_KEY, | ||||||
|         }, |         }, | ||||||
|  |         "azure_openai_config": { | ||||||
|  |             "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, | ||||||
|  |             "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, | ||||||
|  |             "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT, | ||||||
|  |             "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION, | ||||||
|  |         }, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel): | ||||||
|     key: str |     key: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class AzureOpenAIConfigForm(BaseModel): | ||||||
|  |     url: str | ||||||
|  |     key: str | ||||||
|  |     deployment: str | ||||||
|  |     version: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class EmbeddingModelUpdateForm(BaseModel): | class EmbeddingModelUpdateForm(BaseModel): | ||||||
|     openai_config: Optional[OpenAIConfigForm] = None |     openai_config: Optional[OpenAIConfigForm] = None | ||||||
|     ollama_config: Optional[OllamaConfigForm] = None |     ollama_config: Optional[OllamaConfigForm] = None | ||||||
|  |     azure_openai_config: Optional[AzureOpenAIConfigForm] = None | ||||||
|     embedding_engine: str |     embedding_engine: str | ||||||
|     embedding_model: str |     embedding_model: str | ||||||
|     embedding_batch_size: Optional[int] = 1 |     embedding_batch_size: Optional[int] = 1 | ||||||
|  | @ -271,7 +285,7 @@ async def update_embedding_config( | ||||||
|         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine |         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine | ||||||
|         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model |         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||||
| 
 | 
 | ||||||
|         if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: |         if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]: | ||||||
|             if form_data.openai_config is not None: |             if form_data.openai_config is not None: | ||||||
|                 request.app.state.config.RAG_OPENAI_API_BASE_URL = ( |                 request.app.state.config.RAG_OPENAI_API_BASE_URL = ( | ||||||
|                     form_data.openai_config.url |                     form_data.openai_config.url | ||||||
|  | @ -288,6 +302,20 @@ async def update_embedding_config( | ||||||
|                     form_data.ollama_config.key |                     form_data.ollama_config.key | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|  |             if form_data.azure_openai_config is not None: | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( | ||||||
|  |                     form_data.azure_openai_config.url | ||||||
|  |                 ) | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( | ||||||
|  |                     form_data.azure_openai_config.key | ||||||
|  |                 ) | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = ( | ||||||
|  |                     form_data.azure_openai_config.deployment | ||||||
|  |                 ) | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_VERSION = ( | ||||||
|  |                     form_data.azure_openai_config.version | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( |             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( | ||||||
|                 form_data.embedding_batch_size |                 form_data.embedding_batch_size | ||||||
|             ) |             ) | ||||||
|  | @ -304,14 +332,32 @@ async def update_embedding_config( | ||||||
|             ( |             ( | ||||||
|                 request.app.state.config.RAG_OPENAI_API_BASE_URL |                 request.app.state.config.RAG_OPENAI_API_BASE_URL | ||||||
|                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|                 else request.app.state.config.RAG_OLLAMA_BASE_URL |                 else ( | ||||||
|  |                     request.app.state.config.RAG_OLLAMA_BASE_URL | ||||||
|  |                     if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |                     else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL | ||||||
|  |                 ) | ||||||
|             ), |             ), | ||||||
|             ( |             ( | ||||||
|                 request.app.state.config.RAG_OPENAI_API_KEY |                 request.app.state.config.RAG_OPENAI_API_KEY | ||||||
|                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|                 else request.app.state.config.RAG_OLLAMA_API_KEY |                 else ( | ||||||
|  |                     request.app.state.config.RAG_OLLAMA_API_KEY | ||||||
|  |                     if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |                     else request.app.state.config.RAG_AZURE_OPENAI_API_KEY | ||||||
|  |                 ) | ||||||
|             ), |             ), | ||||||
|             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, |             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, | ||||||
|  |             ( | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT | ||||||
|  |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |                 else None | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_VERSION | ||||||
|  |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |                 else None | ||||||
|  |             ), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         return { |         return { | ||||||
|  | @ -327,6 +373,12 @@ async def update_embedding_config( | ||||||
|                 "url": request.app.state.config.RAG_OLLAMA_BASE_URL, |                 "url": request.app.state.config.RAG_OLLAMA_BASE_URL, | ||||||
|                 "key": request.app.state.config.RAG_OLLAMA_API_KEY, |                 "key": request.app.state.config.RAG_OLLAMA_API_KEY, | ||||||
|             }, |             }, | ||||||
|  |             "azure_openai_config": { | ||||||
|  |                 "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, | ||||||
|  |                 "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, | ||||||
|  |                 "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT, | ||||||
|  |                 "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION, | ||||||
|  |             }, | ||||||
|         } |         } | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         log.exception(f"Problem updating embedding model: {e}") |         log.exception(f"Problem updating embedding model: {e}") | ||||||
|  | @ -1129,14 +1181,32 @@ def save_docs_to_vector_db( | ||||||
|             ( |             ( | ||||||
|                 request.app.state.config.RAG_OPENAI_API_BASE_URL |                 request.app.state.config.RAG_OPENAI_API_BASE_URL | ||||||
|                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|                 else request.app.state.config.RAG_OLLAMA_BASE_URL |                 else ( | ||||||
|  |                     request.app.state.config.RAG_OLLAMA_BASE_URL | ||||||
|  |                     if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |                     else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL | ||||||
|  |                 ) | ||||||
|             ), |             ), | ||||||
|             ( |             ( | ||||||
|                 request.app.state.config.RAG_OPENAI_API_KEY |                 request.app.state.config.RAG_OPENAI_API_KEY | ||||||
|                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" | ||||||
|                 else request.app.state.config.RAG_OLLAMA_API_KEY |                 else ( | ||||||
|  |                     request.app.state.config.RAG_OLLAMA_API_KEY | ||||||
|  |                     if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" | ||||||
|  |                     else request.app.state.config.RAG_AZURE_OPENAI_API_KEY | ||||||
|  |                 ) | ||||||
|             ), |             ), | ||||||
|             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, |             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, | ||||||
|  |             ( | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT | ||||||
|  |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |                 else None | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 request.app.state.config.RAG_AZURE_OPENAI_VERSION | ||||||
|  |                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" | ||||||
|  |                 else None | ||||||
|  |             ), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         embeddings = embedding_function( |         embeddings = embedding_function( | ||||||
|  |  | ||||||
|  | @ -180,15 +180,23 @@ export const getEmbeddingConfig = async (token: string) => { | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| type OpenAIConfigForm = { | type OpenAIConfigForm = { | ||||||
| 	key: string; |         key: string; | ||||||
| 	url: string; |         url: string; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | type AzureOpenAIConfigForm = { | ||||||
|  |         key: string; | ||||||
|  |         url: string; | ||||||
|  |         deployment: string; | ||||||
|  |         version: string; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| type EmbeddingModelUpdateForm = { | type EmbeddingModelUpdateForm = { | ||||||
| 	openai_config?: OpenAIConfigForm; |         openai_config?: OpenAIConfigForm; | ||||||
| 	embedding_engine: string; |         azure_openai_config?: AzureOpenAIConfigForm; | ||||||
| 	embedding_model: string; |         embedding_engine: string; | ||||||
| 	embedding_batch_size?: number; |         embedding_model: string; | ||||||
|  |         embedding_batch_size?: number; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { | export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { | ||||||
|  |  | ||||||
|  | @ -43,8 +43,13 @@ | ||||||
| 	let embeddingBatchSize = 1; | 	let embeddingBatchSize = 1; | ||||||
| 	let rerankingModel = ''; | 	let rerankingModel = ''; | ||||||
| 
 | 
 | ||||||
| 	let OpenAIUrl = ''; |         let OpenAIUrl = ''; | ||||||
| 	let OpenAIKey = ''; |         let OpenAIKey = ''; | ||||||
|  | 
 | ||||||
|  |         let AzureOpenAIUrl = ''; | ||||||
|  |         let AzureOpenAIKey = ''; | ||||||
|  |         let AzureOpenAIDeployment = ''; | ||||||
|  |         let AzureOpenAIVersion = ''; | ||||||
| 
 | 
 | ||||||
| 	let OllamaUrl = ''; | 	let OllamaUrl = ''; | ||||||
| 	let OllamaKey = ''; | 	let OllamaKey = ''; | ||||||
|  | @ -86,27 +91,40 @@ | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if ((embeddingEngine === 'openai' && OpenAIKey === '') || OpenAIUrl === '') { |                 if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) { | ||||||
| 			toast.error($i18n.t('OpenAI URL/Key required.')); |                         toast.error($i18n.t('OpenAI URL/Key required.')); | ||||||
| 			return; |                         return; | ||||||
| 		} |                 } | ||||||
|  |                 if ( | ||||||
|  |                         embeddingEngine === 'azure_openai' && | ||||||
|  |                         (AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '') | ||||||
|  |                 ) { | ||||||
|  |                         toast.error($i18n.t('OpenAI URL/Key required.')); | ||||||
|  |                         return; | ||||||
|  |                 } | ||||||
| 
 | 
 | ||||||
| 		console.debug('Update embedding model attempt:', embeddingModel); | 		console.debug('Update embedding model attempt:', embeddingModel); | ||||||
| 
 | 
 | ||||||
| 		updateEmbeddingModelLoading = true; | 		updateEmbeddingModelLoading = true; | ||||||
| 		const res = await updateEmbeddingConfig(localStorage.token, { |                 const res = await updateEmbeddingConfig(localStorage.token, { | ||||||
| 			embedding_engine: embeddingEngine, |                         embedding_engine: embeddingEngine, | ||||||
| 			embedding_model: embeddingModel, |                         embedding_model: embeddingModel, | ||||||
| 			embedding_batch_size: embeddingBatchSize, |                         embedding_batch_size: embeddingBatchSize, | ||||||
| 			ollama_config: { |                         ollama_config: { | ||||||
| 				key: OllamaKey, |                                 key: OllamaKey, | ||||||
| 				url: OllamaUrl |                                 url: OllamaUrl | ||||||
| 			}, |                         }, | ||||||
| 			openai_config: { |                         openai_config: { | ||||||
| 				key: OpenAIKey, |                                 key: OpenAIKey, | ||||||
| 				url: OpenAIUrl |                                 url: OpenAIUrl | ||||||
| 			} |                         }, | ||||||
| 		}).catch(async (error) => { |                         azure_openai_config: { | ||||||
|  |                                 key: AzureOpenAIKey, | ||||||
|  |                                 url: AzureOpenAIUrl, | ||||||
|  |                                 deployment: AzureOpenAIDeployment, | ||||||
|  |                                 version: AzureOpenAIVersion | ||||||
|  |                         } | ||||||
|  |                 }).catch(async (error) => { | ||||||
| 			toast.error(`${error}`); | 			toast.error(`${error}`); | ||||||
| 			await setEmbeddingConfig(); | 			await setEmbeddingConfig(); | ||||||
| 			return null; | 			return null; | ||||||
|  | @ -200,13 +218,18 @@ | ||||||
| 			embeddingModel = embeddingConfig.embedding_model; | 			embeddingModel = embeddingConfig.embedding_model; | ||||||
| 			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1; | 			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1; | ||||||
| 
 | 
 | ||||||
| 			OpenAIKey = embeddingConfig.openai_config.key; |                         OpenAIKey = embeddingConfig.openai_config.key; | ||||||
| 			OpenAIUrl = embeddingConfig.openai_config.url; |                         OpenAIUrl = embeddingConfig.openai_config.url; | ||||||
| 
 | 
 | ||||||
| 			OllamaKey = embeddingConfig.ollama_config.key; |                         OllamaKey = embeddingConfig.ollama_config.key; | ||||||
| 			OllamaUrl = embeddingConfig.ollama_config.url; |                         OllamaUrl = embeddingConfig.ollama_config.url; | ||||||
| 		} | 
 | ||||||
| 	}; |                         AzureOpenAIKey = embeddingConfig.azure_openai_config.key; | ||||||
|  |                         AzureOpenAIUrl = embeddingConfig.azure_openai_config.url; | ||||||
|  |                         AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment; | ||||||
|  |                         AzureOpenAIVersion = embeddingConfig.azure_openai_config.version; | ||||||
|  |                 } | ||||||
|  |         }; | ||||||
| 	onMount(async () => { | 	onMount(async () => { | ||||||
| 		await setEmbeddingConfig(); | 		await setEmbeddingConfig(); | ||||||
| 
 | 
 | ||||||
|  | @ -603,23 +626,26 @@ | ||||||
| 										bind:value={embeddingEngine} | 										bind:value={embeddingEngine} | ||||||
| 										placeholder="Select an embedding model engine" | 										placeholder="Select an embedding model engine" | ||||||
| 										on:change={(e) => { | 										on:change={(e) => { | ||||||
| 											if (e.target.value === 'ollama') { |                                                                         if (e.target.value === 'ollama') { | ||||||
| 												embeddingModel = ''; |                                                                                embeddingModel = ''; | ||||||
| 											} else if (e.target.value === 'openai') { |                                                                        } else if (e.target.value === 'openai') { | ||||||
| 												embeddingModel = 'text-embedding-3-small'; |                                                                                embeddingModel = 'text-embedding-3-small'; | ||||||
| 											} else if (e.target.value === '') { |                                                                        } else if (e.target.value === 'azure_openai') { | ||||||
| 												embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2'; |                                                                                embeddingModel = 'text-embedding-3-small'; | ||||||
| 											} |                                                                        } else if (e.target.value === '') { | ||||||
|  |                                                                                embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2'; | ||||||
|  |                                                                        } | ||||||
| 										}} | 										}} | ||||||
| 									> | 									> | ||||||
| 										<option value="">{$i18n.t('Default (SentenceTransformers)')}</option> | 										<option value="">{$i18n.t('Default (SentenceTransformers)')}</option> | ||||||
| 										<option value="ollama">{$i18n.t('Ollama')}</option> | 										<option value="ollama">{$i18n.t('Ollama')}</option> | ||||||
| 										<option value="openai">{$i18n.t('OpenAI')}</option> |                                                                                <option value="openai">{$i18n.t('OpenAI')}</option> | ||||||
|  |                                                                                <option value="azure_openai">Azure OpenAI</option> | ||||||
| 									</select> | 									</select> | ||||||
| 								</div> | 								</div> | ||||||
| 							</div> | 							</div> | ||||||
| 
 | 
 | ||||||
| 							{#if embeddingEngine === 'openai'} |                                                         {#if embeddingEngine === 'openai'} | ||||||
| 								<div class="my-0.5 flex gap-2 pr-2"> | 								<div class="my-0.5 flex gap-2 pr-2"> | ||||||
| 									<input | 									<input | ||||||
| 										class="flex-1 w-full text-sm bg-transparent outline-hidden" | 										class="flex-1 w-full text-sm bg-transparent outline-hidden" | ||||||
|  | @ -630,7 +656,7 @@ | ||||||
| 
 | 
 | ||||||
| 									<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} /> | 									<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} /> | ||||||
| 								</div> | 								</div> | ||||||
| 							{:else if embeddingEngine === 'ollama'} |                                                         {:else if embeddingEngine === 'ollama'} | ||||||
| 								<div class="my-0.5 flex gap-2 pr-2"> | 								<div class="my-0.5 flex gap-2 pr-2"> | ||||||
| 									<input | 									<input | ||||||
| 										class="flex-1 w-full text-sm bg-transparent outline-hidden" | 										class="flex-1 w-full text-sm bg-transparent outline-hidden" | ||||||
|  | @ -645,7 +671,33 @@ | ||||||
| 										required={false} | 										required={false} | ||||||
| 									/> | 									/> | ||||||
| 								</div> | 								</div> | ||||||
| 							{/if} |                                                         {:else if embeddingEngine === 'azure_openai'} | ||||||
|  |                                                                 <div class="my-0.5 flex flex-col gap-2 pr-2 w-full"> | ||||||
|  |                                                                         <div class="flex gap-2"> | ||||||
|  |                                                                                 <input | ||||||
|  |                                                                                         class="flex-1 w-full text-sm bg-transparent outline-hidden" | ||||||
|  |                                                                                         placeholder={$i18n.t('API Base URL')} | ||||||
|  |                                                                                         bind:value={AzureOpenAIUrl} | ||||||
|  |                                                                                         required | ||||||
|  |                                                                                 /> | ||||||
|  |                                                                                 <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} /> | ||||||
|  |                                                                         </div> | ||||||
|  |                                                                         <div class="flex gap-2"> | ||||||
|  |                                                                                 <input | ||||||
|  |                                                                                         class="flex-1 w-full text-sm bg-transparent outline-hidden" | ||||||
|  |                                                                                         placeholder="Deployment" | ||||||
|  |                                                                                         bind:value={AzureOpenAIDeployment} | ||||||
|  |                                                                                         required | ||||||
|  |                                                                                 /> | ||||||
|  |                                                                                 <input | ||||||
|  |                                                                                         class="flex-1 w-full text-sm bg-transparent outline-hidden" | ||||||
|  |                                                                                         placeholder="Version" | ||||||
|  |                                                                                         bind:value={AzureOpenAIVersion} | ||||||
|  |                                                                                         required | ||||||
|  |                                                                                 /> | ||||||
|  |                                                                         </div> | ||||||
|  |                                                                 </div> | ||||||
|  |                                                         {/if} | ||||||
| 						</div> | 						</div> | ||||||
| 
 | 
 | ||||||
| 						<div class="  mb-2.5 flex flex-col w-full"> | 						<div class="  mb-2.5 flex flex-col w-full"> | ||||||
|  | @ -741,7 +793,7 @@ | ||||||
| 							</div> | 							</div> | ||||||
| 						</div> | 						</div> | ||||||
| 
 | 
 | ||||||
| 						{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'} |                                                 {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'} | ||||||
| 							<div class="  mb-2.5 flex w-full justify-between"> | 							<div class="  mb-2.5 flex w-full justify-between"> | ||||||
| 								<div class=" self-center text-xs font-medium"> | 								<div class=" self-center text-xs font-medium"> | ||||||
| 									{$i18n.t('Embedding Batch Size')} | 									{$i18n.t('Embedding Batch Size')} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue