| 
									
										
										
										
											2025-02-20 22:15:38 +08:00
										 |  |  |  | from contextlib import asynccontextmanager | 
					
						
							|  |  |  |  | from dataclasses import asdict, dataclass | 
					
						
							|  |  |  |  | from enum import Enum | 
					
						
							|  |  |  |  | import re | 
					
						
							|  |  |  |  | from typing import ( | 
					
						
							|  |  |  |  |     TYPE_CHECKING, | 
					
						
							|  |  |  |  |     Any, | 
					
						
							|  |  |  |  |     AsyncGenerator, | 
					
						
							|  |  |  |  |     Dict, | 
					
						
							|  |  |  |  |     MutableMapping, | 
					
						
							|  |  |  |  |     Optional, | 
					
						
							|  |  |  |  |     cast, | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | import uuid | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from asgiref.typing import ( | 
					
						
							|  |  |  |  |     ASGI3Application, | 
					
						
							|  |  |  |  |     ASGIReceiveCallable, | 
					
						
							|  |  |  |  |     ASGIReceiveEvent, | 
					
						
							|  |  |  |  |     ASGISendCallable, | 
					
						
							|  |  |  |  |     ASGISendEvent, | 
					
						
							|  |  |  |  |     Scope as ASGIScope, | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | from loguru import logger | 
					
						
							|  |  |  |  | from starlette.requests import Request | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE | 
					
						
							|  |  |  |  | from open_webui.utils.auth import get_current_user, get_http_authorization_cred | 
					
						
							|  |  |  |  | from open_webui.models.users import UserModel | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | if TYPE_CHECKING: | 
					
						
							|  |  |  |  |     from loguru import Logger | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @dataclass(frozen=True) | 
					
						
							|  |  |  |  | class AuditLogEntry: | 
					
						
							|  |  |  |  |     # `Metadata` audit level properties | 
					
						
							|  |  |  |  |     id: str | 
					
						
							|  |  |  |  |     user: dict[str, Any] | 
					
						
							|  |  |  |  |     audit_level: str | 
					
						
							|  |  |  |  |     verb: str | 
					
						
							|  |  |  |  |     request_uri: str | 
					
						
							|  |  |  |  |     user_agent: Optional[str] = None | 
					
						
							|  |  |  |  |     source_ip: Optional[str] = None | 
					
						
							|  |  |  |  |     # `Request` audit level properties | 
					
						
							|  |  |  |  |     request_object: Any = None | 
					
						
							|  |  |  |  |     # `Request Response` level | 
					
						
							|  |  |  |  |     response_object: Any = None | 
					
						
							|  |  |  |  |     response_status_code: Optional[int] = None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AuditLevel(str, Enum): | 
					
						
							|  |  |  |  |     NONE = "NONE" | 
					
						
							|  |  |  |  |     METADATA = "METADATA" | 
					
						
							|  |  |  |  |     REQUEST = "REQUEST" | 
					
						
							|  |  |  |  |     REQUEST_RESPONSE = "REQUEST_RESPONSE" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AuditLogger: | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     Parameters: | 
					
						
							|  |  |  |  |     logger (Logger): An instance of Loguru’s logger. | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def __init__(self, logger: "Logger"): | 
					
						
							|  |  |  |  |         self.logger = logger.bind(auditable=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def write( | 
					
						
							|  |  |  |  |         self, | 
					
						
							|  |  |  |  |         audit_entry: AuditLogEntry, | 
					
						
							|  |  |  |  |         *, | 
					
						
							|  |  |  |  |         log_level: str = "INFO", | 
					
						
							|  |  |  |  |         extra: Optional[dict] = None, | 
					
						
							|  |  |  |  |     ): | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         entry = asdict(audit_entry) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if extra: | 
					
						
							|  |  |  |  |             entry["extra"] = extra | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         self.logger.log( | 
					
						
							|  |  |  |  |             log_level, | 
					
						
							|  |  |  |  |             "", | 
					
						
							|  |  |  |  |             **entry, | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AuditContext: | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     Attributes: | 
					
						
							|  |  |  |  |     request_body (bytearray): Accumulated request payload. | 
					
						
							|  |  |  |  |     response_body (bytearray): Accumulated response payload. | 
					
						
							|  |  |  |  |     max_body_size (int): Maximum number of bytes to capture. | 
					
						
							|  |  |  |  |     metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.). | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE): | 
					
						
							|  |  |  |  |         self.request_body = bytearray() | 
					
						
							|  |  |  |  |         self.response_body = bytearray() | 
					
						
							|  |  |  |  |         self.max_body_size = max_body_size | 
					
						
							|  |  |  |  |         self.metadata: Dict[str, Any] = {} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def add_request_chunk(self, chunk: bytes): | 
					
						
							|  |  |  |  |         if len(self.request_body) < self.max_body_size: | 
					
						
							|  |  |  |  |             self.request_body.extend( | 
					
						
							|  |  |  |  |                 chunk[: self.max_body_size - len(self.request_body)] | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def add_response_chunk(self, chunk: bytes): | 
					
						
							|  |  |  |  |         if len(self.response_body) < self.max_body_size: | 
					
						
							|  |  |  |  |             self.response_body.extend( | 
					
						
							|  |  |  |  |                 chunk[: self.max_body_size - len(self.response_body)] | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AuditLoggingMiddleware: | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle. | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def __init__( | 
					
						
							|  |  |  |  |         self, | 
					
						
							|  |  |  |  |         app: ASGI3Application, | 
					
						
							|  |  |  |  |         *, | 
					
						
							|  |  |  |  |         excluded_paths: Optional[list[str]] = None, | 
					
						
							|  |  |  |  |         max_body_size: int = MAX_BODY_LOG_SIZE, | 
					
						
							|  |  |  |  |         audit_level: AuditLevel = AuditLevel.NONE, | 
					
						
							|  |  |  |  |     ) -> None: | 
					
						
							|  |  |  |  |         self.app = app | 
					
						
							|  |  |  |  |         self.audit_logger = AuditLogger(logger) | 
					
						
							|  |  |  |  |         self.excluded_paths = excluded_paths or [] | 
					
						
							|  |  |  |  |         self.max_body_size = max_body_size | 
					
						
							|  |  |  |  |         self.audit_level = audit_level | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     async def __call__( | 
					
						
							|  |  |  |  |         self, | 
					
						
							|  |  |  |  |         scope: ASGIScope, | 
					
						
							|  |  |  |  |         receive: ASGIReceiveCallable, | 
					
						
							|  |  |  |  |         send: ASGISendCallable, | 
					
						
							|  |  |  |  |     ) -> None: | 
					
						
							|  |  |  |  |         if scope["type"] != "http": | 
					
						
							|  |  |  |  |             return await self.app(scope, receive, send) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         request = Request(scope=cast(MutableMapping, scope)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if self._should_skip_auditing(request): | 
					
						
							|  |  |  |  |             return await self.app(scope, receive, send) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         async with self._audit_context(request) as context: | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             async def send_wrapper(message: ASGISendEvent) -> None: | 
					
						
							|  |  |  |  |                 if self.audit_level == AuditLevel.REQUEST_RESPONSE: | 
					
						
							|  |  |  |  |                     await self._capture_response(message, context) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 await send(message) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             original_receive = receive | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             async def receive_wrapper() -> ASGIReceiveEvent: | 
					
						
							|  |  |  |  |                 nonlocal original_receive | 
					
						
							|  |  |  |  |                 message = await original_receive() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 if self.audit_level in ( | 
					
						
							|  |  |  |  |                     AuditLevel.REQUEST, | 
					
						
							|  |  |  |  |                     AuditLevel.REQUEST_RESPONSE, | 
					
						
							|  |  |  |  |                 ): | 
					
						
							|  |  |  |  |                     await self._capture_request(message, context) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 return message | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             await self.app(scope, receive_wrapper, send_wrapper) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     @asynccontextmanager | 
					
						
							|  |  |  |  |     async def _audit_context( | 
					
						
							|  |  |  |  |         self, request: Request | 
					
						
							|  |  |  |  |     ) -> AsyncGenerator[AuditContext, None]: | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         async context manager that ensures that an audit log entry is recorded after the request is processed. | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         context = AuditContext() | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             yield context | 
					
						
							|  |  |  |  |         finally: | 
					
						
							|  |  |  |  |             await self._log_audit_entry(request, context) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     async def _get_authenticated_user(self, request: Request) -> UserModel: | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         auth_header = request.headers.get("Authorization") | 
					
						
							|  |  |  |  |         assert auth_header | 
					
						
							| 
									
										
										
										
											2025-02-27 15:35:09 +08:00
										 |  |  |  |         user = get_current_user(request, None, get_http_authorization_cred(auth_header)) | 
					
						
							| 
									
										
										
										
											2025-02-20 22:15:38 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         return user | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def _should_skip_auditing(self, request: Request) -> bool: | 
					
						
							|  |  |  |  |         if ( | 
					
						
							|  |  |  |  |             request.method not in {"POST", "PUT", "PATCH", "DELETE"} | 
					
						
							|  |  |  |  |             or AUDIT_LOG_LEVEL == "NONE" | 
					
						
							|  |  |  |  |             or not request.headers.get("authorization") | 
					
						
							|  |  |  |  |         ): | 
					
						
							|  |  |  |  |             return True | 
					
						
							|  |  |  |  |         # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/... | 
					
						
							|  |  |  |  |         pattern = re.compile( | 
					
						
							|  |  |  |  |             r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b" | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |         if pattern.match(request.url.path): | 
					
						
							|  |  |  |  |             return True | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return False | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext): | 
					
						
							|  |  |  |  |         if message["type"] == "http.request": | 
					
						
							|  |  |  |  |             body = message.get("body", b"") | 
					
						
							|  |  |  |  |             context.add_request_chunk(body) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     async def _capture_response(self, message: ASGISendEvent, context: AuditContext): | 
					
						
							|  |  |  |  |         if message["type"] == "http.response.start": | 
					
						
							|  |  |  |  |             context.metadata["response_status_code"] = message["status"] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         elif message["type"] == "http.response.body": | 
					
						
							|  |  |  |  |             body = message.get("body", b"") | 
					
						
							|  |  |  |  |             context.add_response_chunk(body) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     async def _log_audit_entry(self, request: Request, context: AuditContext): | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             user = await self._get_authenticated_user(request) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             entry = AuditLogEntry( | 
					
						
							|  |  |  |  |                 id=str(uuid.uuid4()), | 
					
						
							|  |  |  |  |                 user=user.model_dump(include={"id", "name", "email", "role"}), | 
					
						
							|  |  |  |  |                 audit_level=self.audit_level.value, | 
					
						
							|  |  |  |  |                 verb=request.method, | 
					
						
							|  |  |  |  |                 request_uri=str(request.url), | 
					
						
							|  |  |  |  |                 response_status_code=context.metadata.get("response_status_code", None), | 
					
						
							|  |  |  |  |                 source_ip=request.client.host if request.client else None, | 
					
						
							|  |  |  |  |                 user_agent=request.headers.get("user-agent"), | 
					
						
							|  |  |  |  |                 request_object=context.request_body.decode("utf-8", errors="replace"), | 
					
						
							|  |  |  |  |                 response_object=context.response_body.decode("utf-8", errors="replace"), | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             self.audit_logger.write(entry) | 
					
						
							|  |  |  |  |         except Exception as e: | 
					
						
							|  |  |  |  |             logger.error(f"Failed to log audit entry: {str(e)}") |