| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | import base64 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | import hashlib | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | import mimetypes | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | import urllib | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  | import json | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | from datetime import datetime, timedelta | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-03 22:50:02 +08:00
										 |  |  | import re | 
					
						
							|  |  |  | import fnmatch | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | import secrets | 
					
						
							|  |  |  | from cryptography.fernet import Fernet | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-03 22:50:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | import aiohttp | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | from authlib.integrations.starlette_client import OAuth | 
					
						
							|  |  |  | from authlib.oidc.core import UserInfo | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | from fastapi import ( | 
					
						
							|  |  |  |     HTTPException, | 
					
						
							|  |  |  |     status, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | from starlette.responses import RedirectResponse | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.models.auths import Auths | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | from open_webui.models.oauth_sessions import OAuthSessions | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.models.users import Users | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  | from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | from open_webui.config import ( | 
					
						
							|  |  |  |     DEFAULT_USER_ROLE, | 
					
						
							|  |  |  |     ENABLE_OAUTH_SIGNUP, | 
					
						
							|  |  |  |     OAUTH_MERGE_ACCOUNTS_BY_EMAIL, | 
					
						
							|  |  |  |     OAUTH_PROVIDERS, | 
					
						
							|  |  |  |     ENABLE_OAUTH_ROLE_MANAGEMENT, | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |     ENABLE_OAUTH_GROUP_MANAGEMENT, | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |     ENABLE_OAUTH_GROUP_CREATION, | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  |     OAUTH_BLOCKED_GROUPS, | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  |     OAUTH_ROLES_CLAIM, | 
					
						
							| 
									
										
										
										
											2025-08-09 04:46:14 +08:00
										 |  |  |     OAUTH_SUB_CLAIM, | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |     OAUTH_GROUPS_CLAIM, | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  |     OAUTH_EMAIL_CLAIM, | 
					
						
							|  |  |  |     OAUTH_PICTURE_CLAIM, | 
					
						
							|  |  |  |     OAUTH_USERNAME_CLAIM, | 
					
						
							|  |  |  |     OAUTH_ALLOWED_ROLES, | 
					
						
							| 
									
										
										
										
											2024-10-21 09:38:06 +08:00
										 |  |  |     OAUTH_ADMIN_ROLES, | 
					
						
							| 
									
										
										
										
											2024-12-02 16:36:56 +08:00
										 |  |  |     OAUTH_ALLOWED_DOMAINS, | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  |     OAUTH_UPDATE_PICTURE_ON_LOGIN, | 
					
						
							| 
									
										
										
										
											2024-10-21 09:38:06 +08:00
										 |  |  |     WEBHOOK_URL, | 
					
						
							|  |  |  |     JWT_EXPIRES_IN, | 
					
						
							|  |  |  |     AppConfig, | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2025-01-08 16:38:00 +08:00
										 |  |  | from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  | from open_webui.env import ( | 
					
						
							| 
									
										
										
										
											2025-05-15 03:33:52 +08:00
										 |  |  |     AIOHTTP_CLIENT_SESSION_SSL, | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  |     WEBUI_NAME, | 
					
						
							|  |  |  |     WEBUI_AUTH_COOKIE_SAME_SITE, | 
					
						
							|  |  |  |     WEBUI_AUTH_COOKIE_SECURE, | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |     ENABLE_OAUTH_ID_TOKEN_COOKIE, | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |     OAUTH_CLIENT_INFO_ENCRYPTION_KEY, | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | from open_webui.utils.misc import parse_duration | 
					
						
							| 
									
										
										
										
											2024-12-09 08:01:56 +08:00
										 |  |  | from open_webui.utils.auth import get_password_hash, create_token | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | from open_webui.utils.webhook import post_webhook | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | from mcp.shared.auth import ( | 
					
						
							|  |  |  |     OAuthClientMetadata, | 
					
						
							|  |  |  |     OAuthMetadata, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthClientInformationFull(OAuthClientMetadata): | 
					
						
							|  |  |  |     issuer: Optional[str] = None  # URL of the OAuth server that issued this client | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     client_id: str | 
					
						
							|  |  |  |     client_secret: str | None = None | 
					
						
							|  |  |  |     client_id_issued_at: int | None = None | 
					
						
							|  |  |  |     client_secret_expires_at: int | None = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  | from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | log = logging.getLogger(__name__) | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  | log.setLevel(SRC_LOG_LEVELS["OAUTH"]) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | auth_manager_config = AppConfig() | 
					
						
							|  |  |  | auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE | 
					
						
							|  |  |  | auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP | 
					
						
							|  |  |  | auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL | 
					
						
							|  |  |  | auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  | auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  | auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM | 
					
						
							| 
									
										
										
										
											2025-08-09 04:46:14 +08:00
										 |  |  | auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM | 
					
						
							|  |  |  | auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM | 
					
						
							|  |  |  | auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM | 
					
						
							|  |  |  | auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES | 
					
						
							|  |  |  | auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES | 
					
						
							| 
									
										
										
										
											2024-12-02 16:36:56 +08:00
										 |  |  | auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | auth_manager_config.WEBHOOK_URL = WEBHOOK_URL | 
					
						
							|  |  |  | auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  | auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | FERNET = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) != 44: | 
					
						
							|  |  |  |     key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest() | 
					
						
							|  |  |  |     OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes) | 
					
						
							|  |  |  | else: | 
					
						
							|  |  |  |     OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) | 
					
						
							|  |  |  | except Exception as e: | 
					
						
							|  |  |  |     log.error(f"Error initializing Fernet with provided key: {e}") | 
					
						
							|  |  |  |     raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  | def encrypt_data(data) -> str: | 
					
						
							|  |  |  |     """Encrypt data for storage""" | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         data_json = json.dumps(data) | 
					
						
							|  |  |  |         encrypted = FERNET.encrypt(data_json.encode()).decode() | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         return encrypted | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         log.error(f"Error encrypting data: {e}") | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  | def decrypt_data(data: str): | 
					
						
							|  |  |  |     """Decrypt data from storage""" | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         decrypted = FERNET.decrypt(data.encode()).decode() | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         return json.loads(decrypted) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         log.error(f"Error decrypting data: {e}") | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-03 22:50:02 +08:00
										 |  |  | def is_in_blocked_groups(group_name: str, groups: list) -> bool: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Check if a group name matches any blocked pattern. | 
					
						
							|  |  |  |     Supports exact matches, shell-style wildcards (*, ?), and regex patterns. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Args: | 
					
						
							|  |  |  |         group_name: The group name to check | 
					
						
							|  |  |  |         groups: List of patterns to match against | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Returns: | 
					
						
							|  |  |  |         True if the group is blocked, False otherwise | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     if not groups: | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for group_pattern in groups: | 
					
						
							|  |  |  |         if not group_pattern:  # Skip empty patterns | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Exact match | 
					
						
							|  |  |  |         if group_name == group_pattern: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Try as regex pattern first if it contains regex-specific characters | 
					
						
							|  |  |  |         if any( | 
					
						
							|  |  |  |             char in group_pattern | 
					
						
							|  |  |  |             for char in ["^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|"] | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 # Use the original pattern as-is for regex matching | 
					
						
							|  |  |  |                 if re.search(group_pattern, group_name): | 
					
						
							|  |  |  |                     return True | 
					
						
							|  |  |  |             except re.error: | 
					
						
							|  |  |  |                 # If regex is invalid, fall through to wildcard check | 
					
						
							|  |  |  |                 pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Shell-style wildcard match (supports * and ?) | 
					
						
							|  |  |  |         if "*" in group_pattern or "?" in group_pattern: | 
					
						
							|  |  |  |             if fnmatch.fnmatch(group_name, group_pattern): | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: | 
					
						
							|  |  |  |     parsed = urllib.parse.urlparse(server_url) | 
					
						
							|  |  |  |     base_url = f"{parsed.scheme}://{parsed.netloc}" | 
					
						
							|  |  |  |     return parsed, base_url | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_discovery_urls(server_url) -> list[str]: | 
					
						
							|  |  |  |     parsed, base_url = get_parsed_and_base_url(server_url) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-27 03:34:26 +08:00
										 |  |  |     urls = [ | 
					
						
							|  |  |  |         urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), | 
					
						
							|  |  |  |         urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if parsed.path and parsed.path != "/": | 
					
						
							|  |  |  |         urls.append( | 
					
						
							|  |  |  |             urllib.parse.urljoin( | 
					
						
							|  |  |  |                 base_url, | 
					
						
							|  |  |  |                 f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}", | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         urls.append( | 
					
						
							|  |  |  |             urllib.parse.urljoin( | 
					
						
							|  |  |  |                 base_url, f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return urls | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration. | 
					
						
							|  |  |  | # This is not currently supported. | 
					
						
							|  |  |  | async def get_oauth_client_info_with_dynamic_client_registration( | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |     request, | 
					
						
							|  |  |  |     client_id: str, | 
					
						
							|  |  |  |     oauth_server_url: str, | 
					
						
							|  |  |  |     oauth_server_key: Optional[str] = None, | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | ) -> OAuthClientInformationFull: | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         oauth_server_metadata = None | 
					
						
							|  |  |  |         oauth_server_metadata_url = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         redirect_base_url = ( | 
					
						
							|  |  |  |             str(request.app.state.config.WEBUI_URL or request.base_url) | 
					
						
							|  |  |  |         ).rstrip("/") | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         oauth_client_metadata = OAuthClientMetadata( | 
					
						
							|  |  |  |             client_name="Open WebUI", | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |             redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |             grant_types=["authorization_code", "refresh_token"], | 
					
						
							|  |  |  |             response_types=["code"], | 
					
						
							|  |  |  |             token_endpoint_auth_method="client_secret_post", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Attempt to fetch OAuth server metadata to get registration endpoint & scopes | 
					
						
							|  |  |  |         discovery_urls = get_discovery_urls(oauth_server_url) | 
					
						
							|  |  |  |         for url in discovery_urls: | 
					
						
							|  |  |  |             async with aiohttp.ClientSession() as session: | 
					
						
							|  |  |  |                 async with session.get( | 
					
						
							|  |  |  |                     url, ssl=AIOHTTP_CLIENT_SESSION_SSL | 
					
						
							|  |  |  |                 ) as oauth_server_metadata_response: | 
					
						
							|  |  |  |                     if oauth_server_metadata_response.status == 200: | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             oauth_server_metadata = OAuthMetadata.model_validate( | 
					
						
							|  |  |  |                                 await oauth_server_metadata_response.json() | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  |                             oauth_server_metadata_url = url | 
					
						
							|  |  |  |                             if ( | 
					
						
							|  |  |  |                                 oauth_client_metadata.scope is None | 
					
						
							|  |  |  |                                 and oauth_server_metadata.scopes_supported is not None | 
					
						
							|  |  |  |                             ): | 
					
						
							|  |  |  |                                 oauth_client_metadata.scope = " ".join( | 
					
						
							|  |  |  |                                     oauth_server_metadata.scopes_supported | 
					
						
							|  |  |  |                                 ) | 
					
						
							|  |  |  |                             break | 
					
						
							|  |  |  |                         except Exception as e: | 
					
						
							|  |  |  |                             log.error(f"Error parsing OAuth metadata from {url}: {e}") | 
					
						
							|  |  |  |                             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         registration_url = None | 
					
						
							|  |  |  |         if oauth_server_metadata and oauth_server_metadata.registration_endpoint: | 
					
						
							|  |  |  |             registration_url = str(oauth_server_metadata.registration_endpoint) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             _, base_url = get_parsed_and_base_url(oauth_server_url) | 
					
						
							|  |  |  |             registration_url = urllib.parse.urljoin(base_url, "/register") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         registration_data = oauth_client_metadata.model_dump( | 
					
						
							|  |  |  |             exclude_none=True, | 
					
						
							|  |  |  |             mode="json", | 
					
						
							|  |  |  |             by_alias=True, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Perform dynamic client registration and return client info | 
					
						
							|  |  |  |         async with aiohttp.ClientSession() as session: | 
					
						
							|  |  |  |             async with session.post( | 
					
						
							|  |  |  |                 registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL | 
					
						
							|  |  |  |             ) as oauth_client_registration_response: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     registration_response_json = ( | 
					
						
							|  |  |  |                         await oauth_client_registration_response.json() | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     oauth_client_info = OAuthClientInformationFull.model_validate( | 
					
						
							|  |  |  |                         { | 
					
						
							|  |  |  |                             **registration_response_json, | 
					
						
							|  |  |  |                             **{"issuer": oauth_server_metadata_url}, | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     log.info( | 
					
						
							|  |  |  |                         f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}" | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     return oauth_client_info | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     error_text = None | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         error_text = await oauth_client_registration_response.text() | 
					
						
							|  |  |  |                         log.error( | 
					
						
							|  |  |  |                             f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                     except Exception as e: | 
					
						
							|  |  |  |                         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     log.error(f"Error parsing client registration response: {e}") | 
					
						
							|  |  |  |                     raise Exception( | 
					
						
							|  |  |  |                         f"Dynamic client registration failed: {error_text}" | 
					
						
							|  |  |  |                         if error_text | 
					
						
							|  |  |  |                         else "Error parsing client registration response" | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |         raise Exception("Dynamic client registration failed") | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.error(f"Exception during dynamic client registration: {e}") | 
					
						
							|  |  |  |         raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class OAuthClientManager: | 
					
						
							|  |  |  |     def __init__(self, app): | 
					
						
							|  |  |  |         self.oauth = OAuth() | 
					
						
							|  |  |  |         self.app = app | 
					
						
							|  |  |  |         self.clients = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         self.clients[client_id] = { | 
					
						
							|  |  |  |             "client": self.oauth.register( | 
					
						
							|  |  |  |                 name=client_id, | 
					
						
							|  |  |  |                 client_id=oauth_client_info.client_id, | 
					
						
							|  |  |  |                 client_secret=oauth_client_info.client_secret, | 
					
						
							|  |  |  |                 client_kwargs=( | 
					
						
							|  |  |  |                     {"scope": oauth_client_info.scope} | 
					
						
							|  |  |  |                     if oauth_client_info.scope | 
					
						
							|  |  |  |                     else {} | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |                 ), | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                 server_metadata_url=( | 
					
						
							|  |  |  |                     oauth_client_info.issuer if oauth_client_info.issuer else None | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             "client_info": oauth_client_info, | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         return self.clients[client_id] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remove_client(self, client_id): | 
					
						
							|  |  |  |         if client_id in self.clients: | 
					
						
							|  |  |  |             del self.clients[client_id] | 
					
						
							|  |  |  |             log.info(f"Removed OAuth client {client_id}") | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_client(self, client_id): | 
					
						
							|  |  |  |         client = self.clients.get(client_id) | 
					
						
							|  |  |  |         return client["client"] if client else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_client_info(self, client_id): | 
					
						
							|  |  |  |         client = self.clients.get(client_id) | 
					
						
							|  |  |  |         return client["client_info"] if client else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_server_metadata_url(self, client_id): | 
					
						
							|  |  |  |         if client_id in self.clients: | 
					
						
							|  |  |  |             client = self.clients[client_id] | 
					
						
							|  |  |  |             return ( | 
					
						
							|  |  |  |                 client.server_metadata_url | 
					
						
							|  |  |  |                 if hasattr(client, "server_metadata_url") | 
					
						
							|  |  |  |                 else None | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def get_oauth_token( | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         self, user_id: str, client_id: str, force_refresh: bool = False | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |     ): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get a valid OAuth token for the user, automatically refreshing if needed. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             user_id: The user ID | 
					
						
							| 
									
										
										
										
											2025-09-25 15:00:02 +08:00
										 |  |  |             client_id: The OAuth client ID (provider) | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |             force_refresh: Force token refresh even if current token appears valid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: OAuth token data with access_token, or None if no valid token available | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Get the OAuth session | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |             session = OAuthSessions.get_session_by_provider_and_user_id( | 
					
						
							|  |  |  |                 client_id, user_id | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |             if not session: | 
					
						
							|  |  |  |                 log.warning( | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                     f"No OAuth session found for user {user_id}, client_id {client_id}" | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if force_refresh or datetime.now() + timedelta( | 
					
						
							|  |  |  |                 minutes=5 | 
					
						
							|  |  |  |             ) >= datetime.fromtimestamp(session.expires_at): | 
					
						
							|  |  |  |                 log.debug( | 
					
						
							|  |  |  |                     f"Token refresh needed for user {user_id}, client_id {session.provider}" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 refreshed_token = await self._refresh_token(session) | 
					
						
							|  |  |  |                 if refreshed_token: | 
					
						
							|  |  |  |                     return refreshed_token | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     log.warning( | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                         f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}" | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                     OAuthSessions.delete_session_by_id(session.id) | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |                     return None | 
					
						
							|  |  |  |             return session.token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error getting OAuth token for user {user_id}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _refresh_token(self, session) -> dict: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Refresh an OAuth token if needed, with concurrency protection. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             session: The OAuth session object | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: Refreshed token data, or None if refresh failed | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Perform the actual refresh | 
					
						
							|  |  |  |             refreshed_token = await self._perform_token_refresh(session) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if refreshed_token: | 
					
						
							|  |  |  |                 # Update the session with new token data | 
					
						
							|  |  |  |                 session = OAuthSessions.update_session_by_id( | 
					
						
							|  |  |  |                     session.id, refreshed_token | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 log.info(f"Successfully refreshed token for session {session.id}") | 
					
						
							|  |  |  |                 return session.token | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 log.error(f"Failed to refresh token for session {session.id}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error refreshing token for session {session.id}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _perform_token_refresh(self, session) -> dict: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Perform the actual OAuth token refresh. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             session: The OAuth session object | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: New token data, or None if refresh failed | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         client_id = session.provider | 
					
						
							|  |  |  |         token_data = session.token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not token_data.get("refresh_token"): | 
					
						
							|  |  |  |             log.warning(f"No refresh token available for session {session.id}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             client = self.get_client(client_id) | 
					
						
							|  |  |  |             if not client: | 
					
						
							|  |  |  |                 log.error(f"No OAuth client found for provider {client_id}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             token_endpoint = None | 
					
						
							|  |  |  |             async with aiohttp.ClientSession(trust_env=True) as session_http: | 
					
						
							|  |  |  |                 async with session_http.get( | 
					
						
							|  |  |  |                     self.get_server_metadata_url(client_id) | 
					
						
							|  |  |  |                 ) as r: | 
					
						
							|  |  |  |                     if r.status == 200: | 
					
						
							|  |  |  |                         openid_data = await r.json() | 
					
						
							|  |  |  |                         token_endpoint = openid_data.get("token_endpoint") | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         log.error( | 
					
						
							|  |  |  |                             f"Failed to fetch OpenID configuration for client_id {client_id}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |             if not token_endpoint: | 
					
						
							|  |  |  |                 log.error(f"No token endpoint found for client_id {client_id}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Prepare refresh request | 
					
						
							|  |  |  |             refresh_data = { | 
					
						
							|  |  |  |                 "grant_type": "refresh_token", | 
					
						
							|  |  |  |                 "refresh_token": token_data["refresh_token"], | 
					
						
							|  |  |  |                 "client_id": client.client_id, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if hasattr(client, "client_secret") and client.client_secret: | 
					
						
							|  |  |  |                 refresh_data["client_secret"] = client.client_secret | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Make refresh request | 
					
						
							|  |  |  |             async with aiohttp.ClientSession(trust_env=True) as session_http: | 
					
						
							|  |  |  |                 async with session_http.post( | 
					
						
							|  |  |  |                     token_endpoint, | 
					
						
							|  |  |  |                     data=refresh_data, | 
					
						
							|  |  |  |                     headers={"Content-Type": "application/x-www-form-urlencoded"}, | 
					
						
							|  |  |  |                     ssl=AIOHTTP_CLIENT_SESSION_SSL, | 
					
						
							|  |  |  |                 ) as r: | 
					
						
							|  |  |  |                     if r.status == 200: | 
					
						
							|  |  |  |                         new_token_data = await r.json() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Merge with existing token data (preserve refresh_token if not provided) | 
					
						
							|  |  |  |                         if "refresh_token" not in new_token_data: | 
					
						
							|  |  |  |                             new_token_data["refresh_token"] = token_data[ | 
					
						
							|  |  |  |                                 "refresh_token" | 
					
						
							|  |  |  |                             ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Add timestamp for tracking | 
					
						
							|  |  |  |                         new_token_data["issued_at"] = datetime.now().timestamp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Calculate expires_at if we have expires_in | 
					
						
							|  |  |  |                         if ( | 
					
						
							|  |  |  |                             "expires_in" in new_token_data | 
					
						
							|  |  |  |                             and "expires_at" not in new_token_data | 
					
						
							|  |  |  |                         ): | 
					
						
							|  |  |  |                             new_token_data["expires_at"] = int( | 
					
						
							|  |  |  |                                 datetime.now().timestamp() | 
					
						
							|  |  |  |                                 + new_token_data["expires_in"] | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         log.debug(f"Token refresh successful for client_id {client_id}") | 
					
						
							|  |  |  |                         return new_token_data | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         error_text = await r.text() | 
					
						
							|  |  |  |                         log.error( | 
					
						
							|  |  |  |                             f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Exception during token refresh for client_id {client_id}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def handle_authorize(self, request, client_id: str) -> RedirectResponse: | 
					
						
							|  |  |  |         client = self.get_client(client_id) | 
					
						
							|  |  |  |         if client is None: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         client_info = self.get_client_info(client_id) | 
					
						
							|  |  |  |         if client_info is None: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         redirect_uri = ( | 
					
						
							|  |  |  |             client_info.redirect_uris[0] if client_info.redirect_uris else None | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         return await client.authorize_redirect(request, str(redirect_uri)) | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def handle_callback(self, request, client_id: str, user_id: str, response): | 
					
						
							|  |  |  |         client = self.get_client(client_id) | 
					
						
							|  |  |  |         if client is None: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         error_message = None | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             token = await client.authorize_access_token(request) | 
					
						
							|  |  |  |             if token: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     # Add timestamp for tracking | 
					
						
							|  |  |  |                     token["issued_at"] = datetime.now().timestamp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Calculate expires_at if we have expires_in | 
					
						
							|  |  |  |                     if "expires_in" in token and "expires_at" not in token: | 
					
						
							|  |  |  |                         token["expires_at"] = ( | 
					
						
							|  |  |  |                             datetime.now().timestamp() + token["expires_in"] | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Clean up any existing sessions for this user/client_id first | 
					
						
							|  |  |  |                     sessions = OAuthSessions.get_sessions_by_user_id(user_id) | 
					
						
							|  |  |  |                     for session in sessions: | 
					
						
							|  |  |  |                         if session.provider == client_id: | 
					
						
							|  |  |  |                             OAuthSessions.delete_session_by_id(session.id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     session = OAuthSessions.create_session( | 
					
						
							|  |  |  |                         user_id=user_id, | 
					
						
							|  |  |  |                         provider=client_id, | 
					
						
							|  |  |  |                         token=token, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     log.info( | 
					
						
							|  |  |  |                         f"Stored OAuth session server-side for user {user_id}, client_id {client_id}" | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     error_message = "Failed to store OAuth session server-side" | 
					
						
							|  |  |  |                     log.error(f"Failed to store OAuth session server-side: {e}") | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 error_message = "Failed to obtain OAuth token" | 
					
						
							|  |  |  |                 log.warning(error_message) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             error_message = "OAuth callback error" | 
					
						
							|  |  |  |             log.warning(f"OAuth callback error: {e}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         redirect_url = ( | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |             str(request.app.state.config.WEBUI_URL or request.base_url) | 
					
						
							|  |  |  |         ).rstrip("/") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if error_message: | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |             log.debug(error_message) | 
					
						
							|  |  |  |             redirect_url = f"{redirect_url}/?error={error_message}" | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |             return RedirectResponse(url=redirect_url, headers=response.headers) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         response = RedirectResponse(url=redirect_url, headers=response.headers) | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |         return response | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | class OAuthManager: | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  |     def __init__(self, app): | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         self.oauth = OAuth() | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  |         self.app = app | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         self._clients = {} | 
					
						
							| 
									
										
										
										
											2025-01-17 12:56:03 +08:00
										 |  |  |         for _, provider_config in OAUTH_PROVIDERS.items(): | 
					
						
							|  |  |  |             provider_config["register"](self.oauth) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |     def get_client(self, provider_name): | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |         if provider_name not in self._clients: | 
					
						
							|  |  |  |             self._clients[provider_name] = self.oauth.create_client(provider_name) | 
					
						
							|  |  |  |         return self._clients[provider_name] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_server_metadata_url(self, provider_name): | 
					
						
							|  |  |  |         if provider_name in self._clients: | 
					
						
							|  |  |  |             client = self._clients[provider_name] | 
					
						
							|  |  |  |             return ( | 
					
						
							|  |  |  |                 client.server_metadata_url | 
					
						
							|  |  |  |                 if hasattr(client, "server_metadata_url") | 
					
						
							|  |  |  |                 else None | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-19 13:10:48 +08:00
										 |  |  |     async def get_oauth_token( | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |         self, user_id: str, session_id: str, force_refresh: bool = False | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get a valid OAuth token for the user, automatically refreshing if needed. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             user_id: The user ID | 
					
						
							|  |  |  |             provider: Optional provider name. If None, gets the most recent session. | 
					
						
							|  |  |  |             force_refresh: Force token refresh even if current token appears valid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: OAuth token data with access_token, or None if no valid token available | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Get the OAuth session | 
					
						
							|  |  |  |             session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id) | 
					
						
							|  |  |  |             if not session: | 
					
						
							|  |  |  |                 log.warning( | 
					
						
							|  |  |  |                     f"No OAuth session found for user {user_id}, session {session_id}" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if force_refresh or datetime.now() + timedelta( | 
					
						
							|  |  |  |                 minutes=5 | 
					
						
							|  |  |  |             ) >= datetime.fromtimestamp(session.expires_at): | 
					
						
							|  |  |  |                 log.debug( | 
					
						
							|  |  |  |                     f"Token refresh needed for user {user_id}, provider {session.provider}" | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2025-09-19 13:10:48 +08:00
										 |  |  |                 refreshed_token = await self._refresh_token(session) | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                 if refreshed_token: | 
					
						
							|  |  |  |                     return refreshed_token | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     log.warning( | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                         f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}" | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2025-09-25 14:49:16 +08:00
										 |  |  |                     OAuthSessions.delete_session_by_id(session.id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                     return None | 
					
						
							|  |  |  |             return session.token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error getting OAuth token for user {user_id}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _refresh_token(self, session) -> dict: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Refresh an OAuth token if needed, with concurrency protection. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             session: The OAuth session object | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: Refreshed token data, or None if refresh failed | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Perform the actual refresh | 
					
						
							|  |  |  |             refreshed_token = await self._perform_token_refresh(session) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if refreshed_token: | 
					
						
							|  |  |  |                 # Update the session with new token data | 
					
						
							|  |  |  |                 session = OAuthSessions.update_session_by_id( | 
					
						
							|  |  |  |                     session.id, refreshed_token | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 log.info(f"Successfully refreshed token for session {session.id}") | 
					
						
							|  |  |  |                 return session.token | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 log.error(f"Failed to refresh token for session {session.id}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error refreshing token for session {session.id}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _perform_token_refresh(self, session) -> dict: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Perform the actual OAuth token refresh. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             session: The OAuth session object | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             dict: New token data, or None if refresh failed | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider = session.provider | 
					
						
							|  |  |  |         token_data = session.token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not token_data.get("refresh_token"): | 
					
						
							|  |  |  |             log.warning(f"No refresh token available for session {session.id}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             client = self.get_client(provider) | 
					
						
							|  |  |  |             if not client: | 
					
						
							|  |  |  |                 log.error(f"No OAuth client found for provider {provider}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-24 19:56:24 +08:00
										 |  |  |             server_metadata_url = self.get_server_metadata_url(provider) | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |             token_endpoint = None | 
					
						
							|  |  |  |             async with aiohttp.ClientSession(trust_env=True) as session_http: | 
					
						
							| 
									
										
										
										
											2025-09-24 19:56:24 +08:00
										 |  |  |                 async with session_http.get(server_metadata_url) as r: | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                     if r.status == 200: | 
					
						
							|  |  |  |                         openid_data = await r.json() | 
					
						
							|  |  |  |                         token_endpoint = openid_data.get("token_endpoint") | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         log.error( | 
					
						
							|  |  |  |                             f"Failed to fetch OpenID configuration for provider {provider}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |             if not token_endpoint: | 
					
						
							|  |  |  |                 log.error(f"No token endpoint found for provider {provider}") | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Prepare refresh request | 
					
						
							|  |  |  |             refresh_data = { | 
					
						
							|  |  |  |                 "grant_type": "refresh_token", | 
					
						
							|  |  |  |                 "refresh_token": token_data["refresh_token"], | 
					
						
							|  |  |  |                 "client_id": client.client_id, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             # Add client_secret if available (some providers require it) | 
					
						
							|  |  |  |             if hasattr(client, "client_secret") and client.client_secret: | 
					
						
							|  |  |  |                 refresh_data["client_secret"] = client.client_secret | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Make refresh request | 
					
						
							|  |  |  |             async with aiohttp.ClientSession(trust_env=True) as session_http: | 
					
						
							|  |  |  |                 async with session_http.post( | 
					
						
							|  |  |  |                     token_endpoint, | 
					
						
							|  |  |  |                     data=refresh_data, | 
					
						
							|  |  |  |                     headers={"Content-Type": "application/x-www-form-urlencoded"}, | 
					
						
							|  |  |  |                     ssl=AIOHTTP_CLIENT_SESSION_SSL, | 
					
						
							|  |  |  |                 ) as r: | 
					
						
							|  |  |  |                     if r.status == 200: | 
					
						
							|  |  |  |                         new_token_data = await r.json() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Merge with existing token data (preserve refresh_token if not provided) | 
					
						
							|  |  |  |                         if "refresh_token" not in new_token_data: | 
					
						
							|  |  |  |                             new_token_data["refresh_token"] = token_data[ | 
					
						
							|  |  |  |                                 "refresh_token" | 
					
						
							|  |  |  |                             ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Add timestamp for tracking | 
					
						
							|  |  |  |                         new_token_data["issued_at"] = datetime.now().timestamp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # Calculate expires_at if we have expires_in | 
					
						
							|  |  |  |                         if ( | 
					
						
							|  |  |  |                             "expires_in" in new_token_data | 
					
						
							|  |  |  |                             and "expires_at" not in new_token_data | 
					
						
							|  |  |  |                         ): | 
					
						
							| 
									
										
										
										
											2025-09-24 19:56:50 +08:00
										 |  |  |                             new_token_data["expires_at"] = int( | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                                 datetime.now().timestamp() | 
					
						
							|  |  |  |                                 + new_token_data["expires_in"] | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         log.debug(f"Token refresh successful for provider {provider}") | 
					
						
							|  |  |  |                         return new_token_data | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         error_text = await r.text() | 
					
						
							|  |  |  |                         log.error( | 
					
						
							|  |  |  |                             f"Token refresh failed for provider {provider}: {r.status} - {error_text}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Exception during token refresh for provider {provider}: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_user_role(self, user, user_data): | 
					
						
							| 
									
										
										
										
											2025-08-06 02:15:22 +08:00
										 |  |  |         user_count = Users.get_num_users() | 
					
						
							|  |  |  |         if user and user_count == 1: | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |             log.debug("Assigning the only user the admin role") | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             return "admin" | 
					
						
							| 
									
										
										
										
											2025-08-06 02:15:22 +08:00
										 |  |  |         if not user and user_count == 0: | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             # If there are no users, assign the role "admin", as the first user will be an admin | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |             log.debug("Assigning the first user the admin role") | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             return "admin" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT: | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |             log.debug("Running OAUTH Role management") | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM | 
					
						
							|  |  |  |             oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES | 
					
						
							|  |  |  |             oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES | 
					
						
							| 
									
										
										
										
											2025-03-10 17:42:59 +08:00
										 |  |  |             oauth_roles = [] | 
					
						
							| 
									
										
										
										
											2025-01-31 22:05:33 +08:00
										 |  |  |             # Default/fallback role if no matching roles are found | 
					
						
							|  |  |  |             role = auth_manager_config.DEFAULT_USER_ROLE | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Next block extracts the roles from the user data, accepting nested claims of any depth | 
					
						
							|  |  |  |             if oauth_claim and oauth_allowed_roles and oauth_admin_roles: | 
					
						
							|  |  |  |                 claim_data = user_data | 
					
						
							|  |  |  |                 nested_claims = oauth_claim.split(".") | 
					
						
							|  |  |  |                 for nested_claim in nested_claims: | 
					
						
							|  |  |  |                     claim_data = claim_data.get(nested_claim, {}) | 
					
						
							| 
									
										
										
										
											2025-08-18 23:49:29 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 oauth_roles = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if isinstance(claim_data, list): | 
					
						
							|  |  |  |                     oauth_roles = claim_data | 
					
						
							|  |  |  |                 if isinstance(claim_data, str) or isinstance(claim_data, int): | 
					
						
							|  |  |  |                     oauth_roles = [str(claim_data)] | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |             log.debug(f"Oauth Roles claim: {oauth_claim}") | 
					
						
							|  |  |  |             log.debug(f"User roles from oauth: {oauth_roles}") | 
					
						
							|  |  |  |             log.debug(f"Accepted user roles: {oauth_allowed_roles}") | 
					
						
							|  |  |  |             log.debug(f"Accepted admin roles: {oauth_admin_roles}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             # If any roles are found, check if they match the allowed or admin roles | 
					
						
							|  |  |  |             if oauth_roles: | 
					
						
							|  |  |  |                 # If role management is enabled, and matching roles are provided, use the roles | 
					
						
							|  |  |  |                 for allowed_role in oauth_allowed_roles: | 
					
						
							|  |  |  |                     # If the user has any of the allowed roles, assign the role "user" | 
					
						
							|  |  |  |                     if allowed_role in oauth_roles: | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |                         log.debug("Assigned user the user role") | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |                         role = "user" | 
					
						
							|  |  |  |                         break | 
					
						
							|  |  |  |                 for admin_role in oauth_admin_roles: | 
					
						
							|  |  |  |                     # If the user has any of the admin roles, assign the role "admin" | 
					
						
							|  |  |  |                     if admin_role in oauth_roles: | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |                         log.debug("Assigned user the admin role") | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |                         role = "admin" | 
					
						
							|  |  |  |                         break | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |             if not user: | 
					
						
							|  |  |  |                 # If role management is disabled, use the default role for new users | 
					
						
							|  |  |  |                 role = auth_manager_config.DEFAULT_USER_ROLE | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 # If role management is disabled, use the existing role for existing users | 
					
						
							|  |  |  |                 role = user.role | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return role | 
					
						
							| 
									
										
										
										
											2024-12-18 05:51:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |     def update_user_groups(self, user, user_data, default_permissions): | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |         log.debug("Running OAUTH Group management") | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |         oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}") | 
					
						
							|  |  |  |             blocked_groups = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-10 17:42:59 +08:00
										 |  |  |         user_oauth_groups = [] | 
					
						
							| 
									
										
										
										
											2025-02-20 00:47:52 +08:00
										 |  |  |         # Nested claim search for groups claim | 
					
						
							|  |  |  |         if oauth_claim: | 
					
						
							|  |  |  |             claim_data = user_data | 
					
						
							|  |  |  |             nested_claims = oauth_claim.split(".") | 
					
						
							|  |  |  |             for nested_claim in nested_claims: | 
					
						
							|  |  |  |                 claim_data = claim_data.get(nested_claim, {}) | 
					
						
							| 
									
										
										
										
											2025-05-07 06:01:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-06 03:38:31 +08:00
										 |  |  |             if isinstance(claim_data, list): | 
					
						
							|  |  |  |                 user_oauth_groups = claim_data | 
					
						
							|  |  |  |             elif isinstance(claim_data, str): | 
					
						
							|  |  |  |                 user_oauth_groups = [claim_data] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 user_oauth_groups = [] | 
					
						
							| 
									
										
										
										
											2025-02-20 00:47:52 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |         user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) | 
					
						
							|  |  |  |         all_available_groups: list[GroupModel] = Groups.get_groups() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |         # Create groups if they don't exist and creation is enabled | 
					
						
							|  |  |  |         if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: | 
					
						
							|  |  |  |             log.debug("Checking for missing groups to create...") | 
					
						
							|  |  |  |             all_group_names = {g.name for g in all_available_groups} | 
					
						
							|  |  |  |             groups_created = False | 
					
						
							|  |  |  |             # Determine creator ID: Prefer admin, fallback to current user if no admin exists | 
					
						
							| 
									
										
										
										
											2025-05-05 23:38:36 +08:00
										 |  |  |             admin_user = Users.get_super_admin_user() | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |             creator_id = admin_user.id if admin_user else user.id | 
					
						
							|  |  |  |             log.debug(f"Using creator ID {creator_id} for potential group creation.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for group_name in user_oauth_groups: | 
					
						
							|  |  |  |                 if group_name not in all_group_names: | 
					
						
							| 
									
										
										
										
											2025-04-23 15:05:15 +08:00
										 |  |  |                     log.info( | 
					
						
							|  |  |  |                         f"Group '{group_name}' not found via OAuth claim. Creating group..." | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |                     try: | 
					
						
							|  |  |  |                         new_group_form = GroupForm( | 
					
						
							|  |  |  |                             name=group_name, | 
					
						
							|  |  |  |                             description=f"Group '{group_name}' created automatically via OAuth.", | 
					
						
							| 
									
										
										
										
											2025-04-23 15:05:15 +08:00
										 |  |  |                             permissions=default_permissions,  # Use default permissions from function args | 
					
						
							|  |  |  |                             user_ids=[],  # Start with no users, user will be added later by subsequent logic | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |                         ) | 
					
						
							|  |  |  |                         # Use determined creator ID (admin or fallback to current user) | 
					
						
							| 
									
										
										
										
											2025-04-23 15:05:15 +08:00
										 |  |  |                         created_group = Groups.insert_new_group( | 
					
						
							|  |  |  |                             creator_id, new_group_form | 
					
						
							|  |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |                         if created_group: | 
					
						
							| 
									
										
										
										
											2025-04-23 15:05:15 +08:00
										 |  |  |                             log.info( | 
					
						
							|  |  |  |                                 f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}" | 
					
						
							|  |  |  |                             ) | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |                             groups_created = True | 
					
						
							|  |  |  |                             # Add to local set to prevent duplicate creation attempts in this run | 
					
						
							|  |  |  |                             all_group_names.add(group_name) | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2025-04-23 15:05:15 +08:00
										 |  |  |                             log.error( | 
					
						
							|  |  |  |                                 f"Failed to create group '{group_name}' via OAuth." | 
					
						
							|  |  |  |                             ) | 
					
						
							| 
									
										
										
										
											2025-04-19 01:17:08 +08:00
										 |  |  |                     except Exception as e: | 
					
						
							|  |  |  |                         log.error(f"Error creating group '{group_name}' via OAuth: {e}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Refresh the list of all available groups if any were created | 
					
						
							|  |  |  |             if groups_created: | 
					
						
							|  |  |  |                 all_available_groups = Groups.get_groups() | 
					
						
							|  |  |  |                 log.debug("Refreshed list of all available groups after creation.") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  |         log.debug(f"Oauth Groups claim: {oauth_claim}") | 
					
						
							|  |  |  |         log.debug(f"User oauth groups: {user_oauth_groups}") | 
					
						
							|  |  |  |         log.debug(f"User's current groups: {[g.name for g in user_current_groups]}") | 
					
						
							| 
									
										
										
										
											2025-02-10 14:20:47 +08:00
										 |  |  |         log.debug( | 
					
						
							|  |  |  |             f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-02-08 03:53:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |         # Remove groups that user is no longer a part of | 
					
						
							|  |  |  |         for group_model in user_current_groups: | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 user_oauth_groups | 
					
						
							|  |  |  |                 and group_model.name not in user_oauth_groups | 
					
						
							| 
									
										
										
										
											2025-09-03 22:50:02 +08:00
										 |  |  |                 and not is_in_blocked_groups(group_model.name, blocked_groups) | 
					
						
							| 
									
										
										
										
											2025-05-02 18:47:02 +08:00
										 |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |                 # Remove group from user | 
					
						
							| 
									
										
										
										
											2025-02-10 14:20:47 +08:00
										 |  |  |                 log.debug( | 
					
						
							|  |  |  |                     f"Removing user from group {group_model.name} as it is no longer in their oauth groups" | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 user_ids = group_model.user_ids | 
					
						
							|  |  |  |                 user_ids = [i for i in user_ids if i != user.id] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # In case a group is created, but perms are never assigned to the group by hitting "save" | 
					
						
							|  |  |  |                 group_permissions = group_model.permissions | 
					
						
							|  |  |  |                 if not group_permissions: | 
					
						
							|  |  |  |                     group_permissions = default_permissions | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 05:51:29 +08:00
										 |  |  |                 update_form = GroupUpdateForm( | 
					
						
							|  |  |  |                     name=group_model.name, | 
					
						
							|  |  |  |                     description=group_model.description, | 
					
						
							|  |  |  |                     permissions=group_permissions, | 
					
						
							|  |  |  |                     user_ids=user_ids, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 Groups.update_group_by_id( | 
					
						
							|  |  |  |                     id=group_model.id, form_data=update_form, overwrite=False | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Add user to new groups | 
					
						
							|  |  |  |         for group_model in all_available_groups: | 
					
						
							| 
									
										
										
										
											2025-03-10 17:42:59 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 user_oauth_groups | 
					
						
							|  |  |  |                 and group_model.name in user_oauth_groups | 
					
						
							|  |  |  |                 and not any(gm.name == group_model.name for gm in user_current_groups) | 
					
						
							| 
									
										
										
										
											2025-09-03 22:50:02 +08:00
										 |  |  |                 and not is_in_blocked_groups(group_model.name, blocked_groups) | 
					
						
							| 
									
										
										
										
											2024-12-18 05:51:29 +08:00
										 |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  |                 # Add user to group | 
					
						
							| 
									
										
										
										
											2025-02-10 14:20:47 +08:00
										 |  |  |                 log.debug( | 
					
						
							|  |  |  |                     f"Adding user to group {group_model.name} as it was found in their oauth groups" | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 user_ids = group_model.user_ids | 
					
						
							|  |  |  |                 user_ids.append(user.id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # In case a group is created, but perms are never assigned to the group by hitting "save" | 
					
						
							|  |  |  |                 group_permissions = group_model.permissions | 
					
						
							|  |  |  |                 if not group_permissions: | 
					
						
							|  |  |  |                     group_permissions = default_permissions | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 05:51:29 +08:00
										 |  |  |                 update_form = GroupUpdateForm( | 
					
						
							|  |  |  |                     name=group_model.name, | 
					
						
							|  |  |  |                     description=group_model.description, | 
					
						
							|  |  |  |                     permissions=group_permissions, | 
					
						
							|  |  |  |                     user_ids=user_ids, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 Groups.update_group_by_id( | 
					
						
							|  |  |  |                     id=group_model.id, form_data=update_form, overwrite=False | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  |     async def _process_picture_url( | 
					
						
							|  |  |  |         self, picture_url: str, access_token: str = None | 
					
						
							|  |  |  |     ) -> str: | 
					
						
							|  |  |  |         """Process a picture URL and return a base64 encoded data URL.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             picture_url: The URL of the picture to process | 
					
						
							|  |  |  |             access_token: Optional OAuth access token for authenticated requests | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             A data URL containing the base64 encoded picture, or "/user.png" if processing fails | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not picture_url: | 
					
						
							|  |  |  |             return "/user.png" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             get_kwargs = {} | 
					
						
							|  |  |  |             if access_token: | 
					
						
							|  |  |  |                 get_kwargs["headers"] = { | 
					
						
							|  |  |  |                     "Authorization": f"Bearer {access_token}", | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2025-05-15 03:27:34 +08:00
										 |  |  |             async with aiohttp.ClientSession(trust_env=True) as session: | 
					
						
							| 
									
										
										
										
											2025-05-15 03:33:52 +08:00
										 |  |  |                 async with session.get( | 
					
						
							|  |  |  |                     picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL | 
					
						
							|  |  |  |                 ) as resp: | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  |                     if resp.ok: | 
					
						
							|  |  |  |                         picture = await resp.read() | 
					
						
							|  |  |  |                         base64_encoded_picture = base64.b64encode(picture).decode( | 
					
						
							|  |  |  |                             "utf-8" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         guessed_mime_type = mimetypes.guess_type(picture_url)[0] | 
					
						
							|  |  |  |                         if guessed_mime_type is None: | 
					
						
							|  |  |  |                             guessed_mime_type = "image/jpeg" | 
					
						
							|  |  |  |                         return ( | 
					
						
							|  |  |  |                             f"data:{guessed_mime_type};base64,{base64_encoded_picture}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         log.warning( | 
					
						
							|  |  |  |                             f"Failed to fetch profile picture from {picture_url}" | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         return "/user.png" | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error processing profile picture '{picture_url}': {e}") | 
					
						
							|  |  |  |             return "/user.png" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  |     async def handle_login(self, request, provider): | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         if provider not in OAUTH_PROVIDERS: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							|  |  |  |         # If the provider has a custom redirect URL, use that, otherwise automatically generate one | 
					
						
							|  |  |  |         redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( | 
					
						
							| 
									
										
										
										
											2025-09-26 00:02:49 +08:00
										 |  |  |             "oauth_login_callback", provider=provider | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         client = self.get_client(provider) | 
					
						
							|  |  |  |         if client is None: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							|  |  |  |         return await client.authorize_redirect(request, redirect_uri) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-16 16:11:18 +08:00
										 |  |  |     async def handle_callback(self, request, provider, response): | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         if provider not in OAUTH_PROVIDERS: | 
					
						
							|  |  |  |             raise HTTPException(404) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         error_message = None | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             client = self.get_client(provider) | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 token = await client.authorize_access_token(request) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.warning(f"OAuth callback error: {e}") | 
					
						
							|  |  |  |                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Try to get userinfo from the token first, some providers include it there | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             user_data: UserInfo = token.get("userinfo") | 
					
						
							|  |  |  |             if ( | 
					
						
							|  |  |  |                 (not user_data) | 
					
						
							|  |  |  |                 or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data) | 
					
						
							|  |  |  |                 or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data) | 
					
						
							|  |  |  |             ): | 
					
						
							|  |  |  |                 user_data: UserInfo = await client.userinfo(token=token) | 
					
						
							| 
									
										
										
										
											2025-09-17 00:16:08 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 provider == "feishu" | 
					
						
							|  |  |  |                 and isinstance(user_data, dict) | 
					
						
							|  |  |  |                 and "data" in user_data | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2025-07-30 16:00:15 +08:00
										 |  |  |                 user_data = user_data["data"] | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             if not user_data: | 
					
						
							|  |  |  |                 log.warning(f"OAuth callback failed, user data is missing: {token}") | 
					
						
							|  |  |  |                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  |             # Extract the "sub" claim, using custom claim if configured | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             if auth_manager_config.OAUTH_SUB_CLAIM: | 
					
						
							|  |  |  |                 sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 # Fallback to the default sub claim if not configured | 
					
						
							|  |  |  |                 sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub")) | 
					
						
							|  |  |  |             if not sub: | 
					
						
							|  |  |  |                 log.warning(f"OAuth callback failed, sub is missing: {user_data}") | 
					
						
							|  |  |  |                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             provider_sub = f"{provider}@{sub}" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  |             # Email extraction | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM | 
					
						
							|  |  |  |             email = user_data.get(email_claim, "") | 
					
						
							|  |  |  |             # We currently mandate that email addresses are provided | 
					
						
							|  |  |  |             if not email: | 
					
						
							|  |  |  |                 # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email | 
					
						
							|  |  |  |                 if provider == "github": | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         access_token = token.get("access_token") | 
					
						
							|  |  |  |                         headers = {"Authorization": f"Bearer {access_token}"} | 
					
						
							|  |  |  |                         async with aiohttp.ClientSession(trust_env=True) as session: | 
					
						
							|  |  |  |                             async with session.get( | 
					
						
							|  |  |  |                                 "https://api.github.com/user/emails", | 
					
						
							|  |  |  |                                 headers=headers, | 
					
						
							|  |  |  |                                 ssl=AIOHTTP_CLIENT_SESSION_SSL, | 
					
						
							|  |  |  |                             ) as resp: | 
					
						
							|  |  |  |                                 if resp.ok: | 
					
						
							|  |  |  |                                     emails = await resp.json() | 
					
						
							|  |  |  |                                     # use the primary email as the user's email | 
					
						
							|  |  |  |                                     primary_email = next( | 
					
						
							|  |  |  |                                         ( | 
					
						
							|  |  |  |                                             e["email"] | 
					
						
							|  |  |  |                                             for e in emails | 
					
						
							|  |  |  |                                             if e.get("primary") | 
					
						
							|  |  |  |                                         ), | 
					
						
							|  |  |  |                                         None, | 
					
						
							| 
									
										
										
										
											2025-02-20 17:01:29 +08:00
										 |  |  |                                     ) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                                     if primary_email: | 
					
						
							|  |  |  |                                         email = primary_email | 
					
						
							|  |  |  |                                     else: | 
					
						
							|  |  |  |                                         log.warning( | 
					
						
							|  |  |  |                                             "No primary email found in GitHub response" | 
					
						
							|  |  |  |                                         ) | 
					
						
							|  |  |  |                                         raise HTTPException( | 
					
						
							|  |  |  |                                             400, detail=ERROR_MESSAGES.INVALID_CRED | 
					
						
							|  |  |  |                                         ) | 
					
						
							|  |  |  |                                 else: | 
					
						
							|  |  |  |                                     log.warning("Failed to fetch GitHub email") | 
					
						
							| 
									
										
										
										
											2025-02-20 17:01:29 +08:00
										 |  |  |                                     raise HTTPException( | 
					
						
							|  |  |  |                                         400, detail=ERROR_MESSAGES.INVALID_CRED | 
					
						
							|  |  |  |                                     ) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                     except Exception as e: | 
					
						
							|  |  |  |                         log.warning(f"Error fetching GitHub email: {e}") | 
					
						
							|  |  |  |                         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     log.warning(f"OAuth callback failed, email is missing: {user_data}") | 
					
						
							| 
									
										
										
										
											2025-02-20 15:06:07 +08:00
										 |  |  |                     raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             email = email.lower() | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # If allowed domains are configured, check if the email domain is in the list | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS | 
					
						
							|  |  |  |                 and email.split("@")[-1] | 
					
						
							|  |  |  |                 not in auth_manager_config.OAUTH_ALLOWED_DOMAINS | 
					
						
							|  |  |  |             ): | 
					
						
							|  |  |  |                 log.warning( | 
					
						
							|  |  |  |                     f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2025-02-20 15:06:07 +08:00
										 |  |  |                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Check if the user exists | 
					
						
							|  |  |  |             user = Users.get_user_by_oauth_sub(provider_sub) | 
					
						
							|  |  |  |             if not user: | 
					
						
							|  |  |  |                 # If the user does not exist, check if merging is enabled | 
					
						
							|  |  |  |                 if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: | 
					
						
							|  |  |  |                     # Check if the user exists by email | 
					
						
							|  |  |  |                     user = Users.get_user_by_email(email) | 
					
						
							|  |  |  |                     if user: | 
					
						
							|  |  |  |                         # Update the user with the new oauth sub | 
					
						
							|  |  |  |                         Users.update_user_oauth_sub_by_id(user.id, provider_sub) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if user: | 
					
						
							|  |  |  |                 determined_role = self.get_user_role(user, user_data) | 
					
						
							|  |  |  |                 if user.role != determined_role: | 
					
						
							|  |  |  |                     Users.update_user_role_by_id(user.id, determined_role) | 
					
						
							|  |  |  |                 # Update profile picture if enabled and different from current | 
					
						
							|  |  |  |                 if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN: | 
					
						
							|  |  |  |                     picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM | 
					
						
							|  |  |  |                     if picture_claim: | 
					
						
							|  |  |  |                         new_picture_url = user_data.get( | 
					
						
							|  |  |  |                             picture_claim, | 
					
						
							|  |  |  |                             OAUTH_PROVIDERS[provider].get("picture_url", ""), | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  |                         ) | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                         processed_picture_url = await self._process_picture_url( | 
					
						
							|  |  |  |                             new_picture_url, token.get("access_token") | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         if processed_picture_url != user.profile_image_url: | 
					
						
							|  |  |  |                             Users.update_user_profile_image_url_by_id( | 
					
						
							|  |  |  |                                 user.id, processed_picture_url | 
					
						
							|  |  |  |                             ) | 
					
						
							|  |  |  |                             log.debug(f"Updated profile picture for user {user.email}") | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                 # If the user does not exist, check if signups are enabled | 
					
						
							|  |  |  |                 if auth_manager_config.ENABLE_OAUTH_SIGNUP: | 
					
						
							|  |  |  |                     # Check if an existing user with the same email already exists | 
					
						
							|  |  |  |                     existing_user = Users.get_user_by_email(email) | 
					
						
							|  |  |  |                     if existing_user: | 
					
						
							|  |  |  |                         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM | 
					
						
							|  |  |  |                     if picture_claim: | 
					
						
							|  |  |  |                         picture_url = user_data.get( | 
					
						
							|  |  |  |                             picture_claim, | 
					
						
							|  |  |  |                             OAUTH_PROVIDERS[provider].get("picture_url", ""), | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                         picture_url = await self._process_picture_url( | 
					
						
							|  |  |  |                             picture_url, token.get("access_token") | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         picture_url = "/user.png" | 
					
						
							|  |  |  |                     username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     name = user_data.get(username_claim) | 
					
						
							|  |  |  |                     if not name: | 
					
						
							|  |  |  |                         log.warning("Username claim is missing, using email as name") | 
					
						
							|  |  |  |                         name = email | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     user = Auths.insert_new_auth( | 
					
						
							|  |  |  |                         email=email, | 
					
						
							|  |  |  |                         password=get_password_hash( | 
					
						
							|  |  |  |                             str(uuid.uuid4()) | 
					
						
							|  |  |  |                         ),  # Random password, not used | 
					
						
							|  |  |  |                         name=name, | 
					
						
							|  |  |  |                         profile_image_url=picture_url, | 
					
						
							| 
									
										
										
										
											2025-09-08 18:36:00 +08:00
										 |  |  |                         role=self.get_user_role(None, user_data), | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                         oauth_sub=provider_sub, | 
					
						
							| 
									
										
										
										
											2025-05-07 00:00:35 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |                     if auth_manager_config.WEBHOOK_URL: | 
					
						
							|  |  |  |                         await post_webhook( | 
					
						
							|  |  |  |                             WEBUI_NAME, | 
					
						
							|  |  |  |                             auth_manager_config.WEBHOOK_URL, | 
					
						
							|  |  |  |                             WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | 
					
						
							|  |  |  |                             { | 
					
						
							|  |  |  |                                 "action": "signup", | 
					
						
							|  |  |  |                                 "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | 
					
						
							|  |  |  |                                 "user": user.model_dump_json(exclude_none=True), | 
					
						
							|  |  |  |                             }, | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     raise HTTPException( | 
					
						
							|  |  |  |                         status.HTTP_403_FORBIDDEN, | 
					
						
							|  |  |  |                         detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |             jwt_token = create_token( | 
					
						
							|  |  |  |                 data={"id": user.id}, | 
					
						
							|  |  |  |                 expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             if ( | 
					
						
							|  |  |  |                 auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT | 
					
						
							|  |  |  |                 and user.role != "admin" | 
					
						
							|  |  |  |             ): | 
					
						
							|  |  |  |                 self.update_user_groups( | 
					
						
							|  |  |  |                     user=user, | 
					
						
							|  |  |  |                     user_data=user_data, | 
					
						
							|  |  |  |                     default_permissions=request.app.state.config.USER_PERMISSIONS, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Error during OAuth process: {e}") | 
					
						
							|  |  |  |             error_message = ( | 
					
						
							|  |  |  |                 e.detail | 
					
						
							|  |  |  |                 if isinstance(e, HTTPException) and e.detail | 
					
						
							|  |  |  |                 else ERROR_MESSAGES.DEFAULT("Error during OAuth process") | 
					
						
							| 
									
										
										
										
											2024-12-18 05:51:29 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-12-18 03:38:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-25 13:28:13 +08:00
										 |  |  |         redirect_base_url = ( | 
					
						
							|  |  |  |             str(request.app.state.config.WEBUI_URL or request.base_url) | 
					
						
							|  |  |  |         ).rstrip("/") | 
					
						
							| 
									
										
										
										
											2025-08-14 06:00:38 +08:00
										 |  |  |         redirect_url = f"{redirect_base_url}/auth" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-07 05:48:52 +08:00
										 |  |  |         if error_message: | 
					
						
							|  |  |  |             redirect_url = f"{redirect_url}?error={error_message}" | 
					
						
							|  |  |  |             return RedirectResponse(url=redirect_url, headers=response.headers) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-14 06:00:38 +08:00
										 |  |  |         response = RedirectResponse(url=redirect_url, headers=response.headers) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         # Set the cookie token | 
					
						
							| 
									
										
										
										
											2025-08-14 06:00:38 +08:00
										 |  |  |         # Redirect back to the frontend with the JWT token | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         response.set_cookie( | 
					
						
							|  |  |  |             key="token", | 
					
						
							|  |  |  |             value=jwt_token, | 
					
						
							| 
									
										
										
										
											2025-08-07 01:02:54 +08:00
										 |  |  |             httponly=False,  # Required for frontend access | 
					
						
							| 
									
										
										
										
											2025-01-23 22:16:50 +08:00
										 |  |  |             samesite=WEBUI_AUTH_COOKIE_SAME_SITE, | 
					
						
							|  |  |  |             secure=WEBUI_AUTH_COOKIE_SECURE, | 
					
						
							| 
									
										
										
										
											2024-10-16 22:32:57 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-16 15:42:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |         # Legacy cookies for compatibility with older frontend versions | 
					
						
							|  |  |  |         if ENABLE_OAUTH_ID_TOKEN_COOKIE: | 
					
						
							|  |  |  |             response.set_cookie( | 
					
						
							|  |  |  |                 key="oauth_id_token", | 
					
						
							|  |  |  |                 value=token.get("id_token"), | 
					
						
							|  |  |  |                 httponly=True, | 
					
						
							|  |  |  |                 samesite=WEBUI_AUTH_COOKIE_SAME_SITE, | 
					
						
							|  |  |  |                 secure=WEBUI_AUTH_COOKIE_SECURE, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-09-08 18:18:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             # Add timestamp for tracking | 
					
						
							|  |  |  |             token["issued_at"] = datetime.now().timestamp() | 
					
						
							| 
									
										
										
										
											2025-09-01 03:42:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |             # Calculate expires_at if we have expires_in | 
					
						
							|  |  |  |             if "expires_in" in token and "expires_at" not in token: | 
					
						
							|  |  |  |                 token["expires_at"] = datetime.now().timestamp() + token["expires_in"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:52:59 +08:00
										 |  |  |             # Clean up any existing sessions for this user/provider first | 
					
						
							|  |  |  |             sessions = OAuthSessions.get_sessions_by_user_id(user.id) | 
					
						
							|  |  |  |             for session in sessions: | 
					
						
							|  |  |  |                 if session.provider == provider: | 
					
						
							|  |  |  |                     OAuthSessions.delete_session_by_id(session.id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 22:17:11 +08:00
										 |  |  |             session = OAuthSessions.create_session( | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                 user_id=user.id, | 
					
						
							|  |  |  |                 provider=provider, | 
					
						
							|  |  |  |                 token=token, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             response.set_cookie( | 
					
						
							|  |  |  |                 key="oauth_session_id", | 
					
						
							| 
									
										
										
										
											2025-09-08 22:17:11 +08:00
										 |  |  |                 value=session.id, | 
					
						
							| 
									
										
										
										
											2025-09-08 22:05:43 +08:00
										 |  |  |                 httponly=True, | 
					
						
							|  |  |  |                 samesite=WEBUI_AUTH_COOKIE_SAME_SITE, | 
					
						
							|  |  |  |                 secure=WEBUI_AUTH_COOKIE_SECURE, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             log.info( | 
					
						
							|  |  |  |                 f"Stored OAuth session server-side for user {user.id}, provider {provider}" | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.error(f"Failed to store OAuth session server-side: {e}") | 
					
						
							| 
									
										
										
										
											2025-09-08 18:18:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-14 06:00:38 +08:00
										 |  |  |         return response |