| 
									
										
										
										
											2025-04-01 22:26:32 +08:00
										 |  |  | import requests | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | from typing import List, Dict, Any | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from langchain_core.documents import Document | 
					
						
							|  |  |  | from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL | 
					
						
							|  |  |  | 
 | 
					
						
# Send log output to stdout at the application's global log level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module-level logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MistralLoader: | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-04-02 01:14:26 +08:00
										 |  |  |     Loads documents by processing them through the Mistral OCR API. | 
					
						
							| 
									
										
										
										
											2025-04-01 22:26:32 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     BASE_API_URL = "https://api.mistral.ai/v1" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, api_key: str, file_path: str): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Initializes the loader. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             api_key: Your Mistral API key. | 
					
						
							|  |  |  |             file_path: The local path to the PDF file to process. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not api_key: | 
					
						
							|  |  |  |             raise ValueError("API key cannot be empty.") | 
					
						
							|  |  |  |         if not os.path.exists(file_path): | 
					
						
							|  |  |  |             raise FileNotFoundError(f"File not found at {file_path}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.api_key = api_key | 
					
						
							|  |  |  |         self.file_path = file_path | 
					
						
							|  |  |  |         self.headers = {"Authorization": f"Bearer {self.api_key}"} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _handle_response(self, response: requests.Response) -> Dict[str, Any]: | 
					
						
							|  |  |  |         """Checks response status and returns JSON content.""" | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx) | 
					
						
							|  |  |  |             # Handle potential empty responses for certain successful requests (e.g., DELETE) | 
					
						
							|  |  |  |             if response.status_code == 204 or not response.content: | 
					
						
							|  |  |  |                 return {}  # Return empty dict if no content | 
					
						
							|  |  |  |             return response.json() | 
					
						
							|  |  |  |         except requests.exceptions.HTTPError as http_err: | 
					
						
							|  |  |  |             log.error(f"HTTP error occurred: {http_err} - Response: {response.text}") | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  |         except requests.exceptions.RequestException as req_err: | 
					
						
							|  |  |  |             log.error(f"Request exception occurred: {req_err}") | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  |         except ValueError as json_err:  # Includes JSONDecodeError | 
					
						
							|  |  |  |             log.error(f"JSON decode error: {json_err} - Response: {response.text}") | 
					
						
							|  |  |  |             raise  # Re-raise after logging | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _upload_file(self) -> str: | 
					
						
							|  |  |  |         """Uploads the file to Mistral for OCR processing.""" | 
					
						
							|  |  |  |         log.info("Uploading file to Mistral API") | 
					
						
							|  |  |  |         url = f"{self.BASE_API_URL}/files" | 
					
						
							|  |  |  |         file_name = os.path.basename(self.file_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with open(self.file_path, "rb") as f: | 
					
						
							|  |  |  |                 files = {"file": (file_name, f, "application/pdf")} | 
					
						
							|  |  |  |                 data = {"purpose": "ocr"} | 
					
						
							| 
									
										
										
										
											2025-04-02 01:14:26 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 22:26:32 +08:00
										 |  |  |                 upload_headers = self.headers.copy()  # Avoid modifying self.headers | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 response = requests.post( | 
					
						
							|  |  |  |                     url, headers=upload_headers, files=files, data=data | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             response_data = self._handle_response(response) | 
					
						
							|  |  |  |             file_id = response_data.get("id") | 
					
						
							|  |  |  |             if not file_id: | 
					
						
							|  |  |  |                 raise ValueError("File ID not found in upload response.") | 
					
						
							|  |  |  |             log.info(f"File uploaded successfully. File ID: {file_id}") | 
					
						
							|  |  |  |             return file_id | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Failed to upload file: {e}") | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_signed_url(self, file_id: str) -> str: | 
					
						
							|  |  |  |         """Retrieves a temporary signed URL for the uploaded file.""" | 
					
						
							|  |  |  |         log.info(f"Getting signed URL for file ID: {file_id}") | 
					
						
							|  |  |  |         url = f"{self.BASE_API_URL}/files/{file_id}/url" | 
					
						
							| 
									
										
										
										
											2025-04-02 01:14:26 +08:00
										 |  |  |         params = {"expiry": 1} | 
					
						
							| 
									
										
										
										
											2025-04-01 22:26:32 +08:00
										 |  |  |         signed_url_headers = {**self.headers, "Accept": "application/json"} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = requests.get(url, headers=signed_url_headers, params=params) | 
					
						
							|  |  |  |             response_data = self._handle_response(response) | 
					
						
							|  |  |  |             signed_url = response_data.get("url") | 
					
						
							|  |  |  |             if not signed_url: | 
					
						
							|  |  |  |                 raise ValueError("Signed URL not found in response.") | 
					
						
							|  |  |  |             log.info("Signed URL received.") | 
					
						
							|  |  |  |             return signed_url | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Failed to get signed URL: {e}") | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _process_ocr(self, signed_url: str) -> Dict[str, Any]: | 
					
						
							|  |  |  |         """Sends the signed URL to the OCR endpoint for processing.""" | 
					
						
							|  |  |  |         log.info("Processing OCR via Mistral API") | 
					
						
							|  |  |  |         url = f"{self.BASE_API_URL}/ocr" | 
					
						
							|  |  |  |         ocr_headers = { | 
					
						
							|  |  |  |             **self.headers, | 
					
						
							|  |  |  |             "Content-Type": "application/json", | 
					
						
							|  |  |  |             "Accept": "application/json", | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         payload = { | 
					
						
							|  |  |  |             "model": "mistral-ocr-latest", | 
					
						
							|  |  |  |             "document": { | 
					
						
							|  |  |  |                 "type": "document_url", | 
					
						
							|  |  |  |                 "document_url": signed_url, | 
					
						
							|  |  |  |             }, | 
					
						
							| 
									
										
										
										
											2025-04-02 01:14:26 +08:00
										 |  |  |             "include_image_base64": False, | 
					
						
							| 
									
										
										
										
											2025-04-01 22:26:32 +08:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = requests.post(url, headers=ocr_headers, json=payload) | 
					
						
							|  |  |  |             ocr_response = self._handle_response(response) | 
					
						
							|  |  |  |             log.info("OCR processing done.") | 
					
						
							|  |  |  |             log.debug("OCR response: %s", ocr_response) | 
					
						
							|  |  |  |             return ocr_response | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Failed during OCR processing: {e}") | 
					
						
							|  |  |  |             raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _delete_file(self, file_id: str) -> None: | 
					
						
							|  |  |  |         """Deletes the file from Mistral storage.""" | 
					
						
							|  |  |  |         log.info(f"Deleting uploaded file ID: {file_id}") | 
					
						
							|  |  |  |         url = f"{self.BASE_API_URL}/files/{file_id}" | 
					
						
							|  |  |  |         # No specific Accept header needed, default or Authorization is usually sufficient | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = requests.delete(url, headers=self.headers) | 
					
						
							|  |  |  |             delete_response = self._handle_response( | 
					
						
							|  |  |  |                 response | 
					
						
							|  |  |  |             )  # Check status, ignore response body unless needed | 
					
						
							|  |  |  |             log.info( | 
					
						
							|  |  |  |                 f"File deleted successfully: {delete_response}" | 
					
						
							|  |  |  |             )  # Log the response if available | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             # Log error but don't necessarily halt execution if deletion fails | 
					
						
							|  |  |  |             log.error(f"Failed to delete file ID {file_id}: {e}") | 
					
						
							|  |  |  |             # Depending on requirements, you might choose to raise the error here | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def load(self) -> List[Document]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Executes the full OCR workflow: upload, get URL, process OCR, delete file. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             A list of Document objects, one for each page processed. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         file_id = None | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # 1. Upload file | 
					
						
							|  |  |  |             file_id = self._upload_file() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # 2. Get Signed URL | 
					
						
							|  |  |  |             signed_url = self._get_signed_url(file_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # 3. Process OCR | 
					
						
							|  |  |  |             ocr_response = self._process_ocr(signed_url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # 4. Process results | 
					
						
							|  |  |  |             pages_data = ocr_response.get("pages") | 
					
						
							|  |  |  |             if not pages_data: | 
					
						
							|  |  |  |                 log.warning("No pages found in OCR response.") | 
					
						
							|  |  |  |                 return [Document(page_content="No text content found", metadata={})] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             documents = [] | 
					
						
							|  |  |  |             total_pages = len(pages_data) | 
					
						
							|  |  |  |             for page_data in pages_data: | 
					
						
							|  |  |  |                 page_content = page_data.get("markdown") | 
					
						
							|  |  |  |                 page_index = page_data.get("index")  # API uses 0-based index | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if page_content is not None and page_index is not None: | 
					
						
							|  |  |  |                     documents.append( | 
					
						
							|  |  |  |                         Document( | 
					
						
							|  |  |  |                             page_content=page_content, | 
					
						
							|  |  |  |                             metadata={ | 
					
						
							|  |  |  |                                 "page": page_index,  # 0-based index from API | 
					
						
							|  |  |  |                                 "page_label": page_index | 
					
						
							|  |  |  |                                 + 1,  # 1-based label for convenience | 
					
						
							|  |  |  |                                 "total_pages": total_pages, | 
					
						
							|  |  |  |                                 # Add other relevant metadata from page_data if available/needed | 
					
						
							|  |  |  |                                 # e.g., page_data.get('width'), page_data.get('height') | 
					
						
							|  |  |  |                             }, | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     log.warning( | 
					
						
							|  |  |  |                         f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not documents: | 
					
						
							|  |  |  |                 # Case where pages existed but none had valid markdown/index | 
					
						
							|  |  |  |                 log.warning( | 
					
						
							|  |  |  |                     "OCR response contained pages, but none had valid content/index." | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 return [ | 
					
						
							|  |  |  |                     Document( | 
					
						
							|  |  |  |                         page_content="No text content found in valid pages", metadata={} | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return documents | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"An error occurred during the loading process: {e}") | 
					
						
							|  |  |  |             # Return an empty list or a specific error document on failure | 
					
						
							|  |  |  |             return [Document(page_content=f"Error during processing: {e}", metadata={})] | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             # 5. Delete file (attempt even if prior steps failed after upload) | 
					
						
							|  |  |  |             if file_id: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     self._delete_file(file_id) | 
					
						
							|  |  |  |                 except Exception as del_e: | 
					
						
							|  |  |  |                     # Log deletion error, but don't overwrite original error if one occurred | 
					
						
							|  |  |  |                     log.error( | 
					
						
							|  |  |  |                         f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" | 
					
						
							|  |  |  |                     ) |