1339 lines
		
	
	
		
			54 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			1339 lines
		
	
	
		
			54 KiB
		
	
	
	
		
			Python
		
	
	
	
| import base64
 | |
| import hashlib
 | |
| import logging
 | |
| import mimetypes
 | |
| import sys
 | |
| import urllib
 | |
| import uuid
 | |
| import json
 | |
| from datetime import datetime, timedelta
 | |
| 
 | |
| import re
 | |
| import fnmatch
 | |
| import time
 | |
| import secrets
 | |
| from cryptography.fernet import Fernet
 | |
| 
 | |
| 
 | |
| import aiohttp
 | |
| from authlib.integrations.starlette_client import OAuth
 | |
| from authlib.oidc.core import UserInfo
 | |
| from fastapi import (
 | |
|     HTTPException,
 | |
|     status,
 | |
| )
 | |
| from starlette.responses import RedirectResponse
 | |
| from typing import Optional
 | |
| 
 | |
| 
 | |
| from open_webui.models.auths import Auths
 | |
| from open_webui.models.oauth_sessions import OAuthSessions
 | |
| from open_webui.models.users import Users
 | |
| 
 | |
| 
 | |
| from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
 | |
| from open_webui.config import (
 | |
|     DEFAULT_USER_ROLE,
 | |
|     ENABLE_OAUTH_SIGNUP,
 | |
|     OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
 | |
|     OAUTH_PROVIDERS,
 | |
|     ENABLE_OAUTH_ROLE_MANAGEMENT,
 | |
|     ENABLE_OAUTH_GROUP_MANAGEMENT,
 | |
|     ENABLE_OAUTH_GROUP_CREATION,
 | |
|     OAUTH_BLOCKED_GROUPS,
 | |
|     OAUTH_ROLES_CLAIM,
 | |
|     OAUTH_SUB_CLAIM,
 | |
|     OAUTH_GROUPS_CLAIM,
 | |
|     OAUTH_EMAIL_CLAIM,
 | |
|     OAUTH_PICTURE_CLAIM,
 | |
|     OAUTH_USERNAME_CLAIM,
 | |
|     OAUTH_ALLOWED_ROLES,
 | |
|     OAUTH_ADMIN_ROLES,
 | |
|     OAUTH_ALLOWED_DOMAINS,
 | |
|     OAUTH_UPDATE_PICTURE_ON_LOGIN,
 | |
|     WEBHOOK_URL,
 | |
|     JWT_EXPIRES_IN,
 | |
|     AppConfig,
 | |
| )
 | |
| from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 | |
| from open_webui.env import (
 | |
|     AIOHTTP_CLIENT_SESSION_SSL,
 | |
|     WEBUI_NAME,
 | |
|     WEBUI_AUTH_COOKIE_SAME_SITE,
 | |
|     WEBUI_AUTH_COOKIE_SECURE,
 | |
|     ENABLE_OAUTH_ID_TOKEN_COOKIE,
 | |
|     OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
 | |
| )
 | |
| from open_webui.utils.misc import parse_duration
 | |
| from open_webui.utils.auth import get_password_hash, create_token
 | |
| from open_webui.utils.webhook import post_webhook
 | |
| 
 | |
| from mcp.shared.auth import (
 | |
|     OAuthClientMetadata,
 | |
|     OAuthMetadata,
 | |
| )
 | |
| 
 | |
| 
 | |
class OAuthClientInformationFull(OAuthClientMetadata):
    """Client metadata plus the credentials returned by dynamic client registration.

    Extends the MCP ``OAuthClientMetadata`` model with the fields read from a
    client registration response (field names follow the RFC 7591 client
    information response) and with the URL the server metadata was discovered at.
    """

    # URL of the OAuth server that issued this client.
    # NOTE: within this module it is filled with the metadata-discovery URL that
    # succeeded (or None), and later handed to authlib as ``server_metadata_url``.
    issuer: Optional[str] = None

    # Credentials / timestamps from the registration response.
    client_id: str
    client_secret: str | None = None
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
 | |
| 
 | |
| 
 | |
| from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
 | |
| 
 | |
# Module logging. ``logging.basicConfig`` only installs a stdout handler on the
# root logger if it has no handlers yet; this module's logger level comes from
# the "OAUTH" entry of SRC_LOG_LEVELS.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

# Gather the OAuth-related settings imported from open_webui.config onto one
# AppConfig instance so they can be read as attributes of a single object.
# NOTE(review): whether AppConfig persists or live-reloads these values is
# defined in open_webui.config and not visible here — confirm before relying on it.
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
# Names of the token/userinfo claims to read role, subject, group, email,
# picture and username information from.
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
 | |
| 
 | |
| 
 | |
FERNET = None

if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) == 44:
    # 44 characters is the length of a url-safe base64 encoding of 32 bytes,
    # i.e. the shape of a native Fernet key: use the configured value verbatim.
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
else:
    # Any other secret is turned into a valid Fernet key by hashing it with
    # SHA-256 and url-safe base64-encoding the 32-byte digest.
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(
        hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
    )

# Fail fast at import time: encrypt_data/decrypt_data cannot work without a key.
try:
    FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
except Exception as e:
    log.error(f"Error initializing Fernet with provided key: {e}")
    raise
 | |
| 
 | |
| 
 | |
def encrypt_data(data) -> str:
    """Serialize *data* to JSON and encrypt it with the module-level Fernet key.

    Returns the Fernet token as text. Serialization or encryption errors are
    logged and re-raised.
    """
    try:
        plaintext = json.dumps(data).encode()
        return FERNET.encrypt(plaintext).decode()
    except Exception as e:
        log.error(f"Error encrypting data: {e}")
        raise
 | |
| 
 | |
| 
 | |
def decrypt_data(data: str):
    """Decrypt a Fernet token (e.g. one from ``encrypt_data``) and parse its JSON payload.

    Decryption or parsing errors are logged and re-raised.
    """
    try:
        plaintext = FERNET.decrypt(data.encode())
        return json.loads(plaintext.decode())
    except Exception as e:
        log.error(f"Error decrypting data: {e}")
        raise
 | |
| 
 | |
| 
 | |
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
    """
    Check if a group name matches any blocked pattern.

    Each non-empty pattern is tried, in order, as:

    1. an exact string match;
    2. a regular expression, when it contains a regex-only character
       (``^ $ [ ] ( ) { } + \\ |``). ``re.search`` is used, so the pattern is
       unanchored unless it uses ``^``/``$``; an invalid regex skips this step;
    3. a shell-style wildcard match, when it contains ``*`` or ``?``.

    Every step is case-sensitive on every platform: ``fnmatch.fnmatchcase`` is
    used because ``fnmatch.fnmatch`` applies ``os.path.normcase`` and would
    become case-insensitive on Windows only.

    Args:
        group_name: The group name to check
        groups: Patterns to match against; empty or None blocks nothing

    Returns:
        True if the group is blocked, False otherwise
    """
    if not groups:
        return False

    regex_only_chars = "^$[](){}+\\|"

    for group_pattern in groups:
        if not group_pattern:  # Skip empty patterns
            continue

        if group_name == group_pattern:
            return True

        if any(char in group_pattern for char in regex_only_chars):
            try:
                if re.search(group_pattern, group_name):
                    return True
            except re.error:
                # Not a valid regex; still give the wildcard check a chance.
                pass

        if ("*" in group_pattern or "?" in group_pattern) and fnmatch.fnmatchcase(
            group_name, group_pattern
        ):
            return True

    return False
 | |
| 
 | |
| 
 | |
def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
    """Parse *server_url* and return ``(parse_result, "scheme://netloc")``.

    The second element keeps only scheme and network location, dropping any
    path, query and fragment.
    """
    components = urllib.parse.urlparse(server_url)
    origin = f"{components.scheme}://{components.netloc}"
    return components, origin
 | |
| 
 | |
| 
 | |
def get_discovery_urls(server_url) -> list[str]:
    """Return candidate authorization-server metadata URLs for *server_url*, in lookup order.

    Root-level well-known documents are listed first. When the URL has a
    non-root path, the path-inserted forms follow: RFC 8414 style
    ``/.well-known/oauth-authorization-server/<path>`` and its OpenID variant.
    The OpenID Connect Discovery 1.0 path-appended form
    ``/<path>/.well-known/openid-configuration`` comes last. Callers use the
    first URL that yields valid metadata, so later entries are only fallbacks.
    """
    parsed, base_url = get_parsed_and_base_url(server_url)

    urls = [
        urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
        urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
    ]

    # Strip trailing slashes once: "/tenant/" and "/tenant" give the same URLs,
    # and a path made only of slashes adds no (duplicate) candidates.
    path = parsed.path.rstrip("/")
    if path:
        urls += [
            urllib.parse.urljoin(
                base_url, f"/.well-known/oauth-authorization-server{path}"
            ),
            urllib.parse.urljoin(base_url, f"/.well-known/openid-configuration{path}"),
            # OIDC Discovery 1.0: issuers with a path publish their configuration
            # by *appending* the well-known suffix to the issuer URL.
            urllib.parse.urljoin(base_url, f"{path}/.well-known/openid-configuration"),
        ]

    return urls
 | |
| 
 | |
| 
 | |
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
# This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
    request,
    client_id: str,
    oauth_server_url: str,
    oauth_server_key: Optional[str] = None,
) -> OAuthClientInformationFull:
    """Register Open WebUI as a client of an OAuth server via dynamic client registration.

    Discovery is best-effort:

    - The URLs from ``get_discovery_urls(oauth_server_url)`` are tried in order.
    - The first one answering 200 with a valid ``OAuthMetadata`` document supplies
      the registration endpoint and, if no scope is set, ``scopes_supported``.
    - Network errors and unparsable documents are logged and the next URL is tried.
    - Without such metadata, or without a ``registration_endpoint`` in it,
      ``<scheme>://<host>/register`` is used.

    Args:
        request: Incoming request. ``request.app.state.config.WEBUI_URL`` (or the
            request base URL) forms the redirect URI
            ``<base>/oauth/clients/<client_id>/callback``.
        client_id: Local identifier of the client; only used in the redirect URI.
        oauth_server_url: URL to discover metadata for and register with.
        oauth_server_key: Currently unused.

    Returns:
        The registration response parsed as ``OAuthClientInformationFull``, with
        ``issuer`` set to the discovery URL that succeeded (None if none did).

    Raises:
        Exception: if the registration response cannot be parsed. Errors from
            the registration request itself are logged and re-raised unchanged.
    """
    import asyncio  # local: only needed to catch aiohttp timeouts during discovery

    try:
        oauth_server_metadata = None
        oauth_server_metadata_url = None

        redirect_base_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        oauth_client_metadata = OAuthClientMetadata(
            client_name="Open WebUI",
            redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="client_secret_post",
        )

        # Attempt to fetch OAuth server metadata to get registration endpoint & scopes.
        # One HTTP session is reused for every candidate URL.
        async with aiohttp.ClientSession() as session:
            for url in get_discovery_urls(oauth_server_url):
                try:
                    async with session.get(
                        url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                    ) as oauth_server_metadata_response:
                        if oauth_server_metadata_response.status != 200:
                            continue
                        try:
                            oauth_server_metadata = OAuthMetadata.model_validate(
                                await oauth_server_metadata_response.json()
                            )
                        except Exception as e:
                            log.error(f"Error parsing OAuth metadata from {url}: {e}")
                            continue
                except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                    # One unreachable candidate must not abort discovery; the
                    # remaining URLs and the /register fallback are still tried.
                    log.warning(f"Error fetching OAuth metadata from {url}: {e}")
                    continue

                oauth_server_metadata_url = url
                if (
                    oauth_client_metadata.scope is None
                    and oauth_server_metadata.scopes_supported is not None
                ):
                    oauth_client_metadata.scope = " ".join(
                        oauth_server_metadata.scopes_supported
                    )
                break

        if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
            registration_url = str(oauth_server_metadata.registration_endpoint)
        else:
            _, base_url = get_parsed_and_base_url(oauth_server_url)
            registration_url = urllib.parse.urljoin(base_url, "/register")

        registration_data = oauth_client_metadata.model_dump(
            exclude_none=True,
            mode="json",
            by_alias=True,
        )

        # Perform dynamic client registration and return client info
        async with aiohttp.ClientSession() as session:
            async with session.post(
                registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as oauth_client_registration_response:
                try:
                    registration_response_json = (
                        await oauth_client_registration_response.json()
                    )
                    oauth_client_info = OAuthClientInformationFull.model_validate(
                        {
                            **registration_response_json,
                            "issuer": oauth_server_metadata_url,
                        }
                    )
                except Exception as e:
                    error_text = None
                    try:
                        error_text = await oauth_client_registration_response.text()
                        log.error(
                            f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
                        )
                    except Exception as text_error:
                        # Distinct name on purpose: re-binding ``e`` here would
                        # unbind it when this handler exits, so the log line
                        # below would raise instead of reporting the real error.
                        log.debug(
                            f"Could not read client registration response body: {text_error}"
                        )

                    log.error(f"Error parsing client registration response: {e}")
                    raise Exception(
                        f"Dynamic client registration failed: {error_text}"
                        if error_text
                        else "Error parsing client registration response"
                    ) from e

        log.info(
            f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
        )
        return oauth_client_info
    except Exception as e:
        log.error(f"Exception during dynamic client registration: {e}")
        raise
 | |
| 
 | |
| 
 | |
class OAuthClientManager:
    """Registry of OAuth clients keyed by a local client ID.

    Each entry pairs the authlib client registered on this manager's own
    ``OAuth`` instance with the ``OAuthClientInformationFull`` it was built
    from. Tokens are stored server-side via ``OAuthSessions``, with the local
    client ID recorded as the session's ``provider``.
    """

    def __init__(self, app):
        self.oauth = OAuth()
        self.app = app
        # client_id -> {"client": <authlib client>, "client_info": OAuthClientInformationFull}
        self.clients = {}

    def _forget_authlib_client(self, name):
        """Drop authlib's cached client object for *name*, if there is one.

        NOTE(review): relies on authlib keeping created clients in the private
        ``OAuth._clients`` dict, where ``register()`` returns the cached object
        for a name it has already seen. If that attribute is absent, this is a no-op.
        """
        cached_clients = getattr(self.oauth, "_clients", None)
        if isinstance(cached_clients, dict):
            cached_clients.pop(name, None)

    def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
        """Register (or re-register) *client_id* with authlib and store its client info.

        Returns:
            The stored entry: ``{"client": <authlib client>, "client_info": ...}``.
        """
        # Otherwise re-adding a client_id would reuse the previously created
        # authlib client together with its old credentials.
        self._forget_authlib_client(client_id)
        self.clients[client_id] = {
            "client": self.oauth.register(
                name=client_id,
                client_id=oauth_client_info.client_id,
                client_secret=oauth_client_info.client_secret,
                client_kwargs=(
                    {"scope": oauth_client_info.scope}
                    if oauth_client_info.scope
                    else {}
                ),
                server_metadata_url=(
                    oauth_client_info.issuer if oauth_client_info.issuer else None
                ),
            ),
            "client_info": oauth_client_info,
        }
        return self.clients[client_id]

    def remove_client(self, client_id):
        """Forget *client_id*. Always returns True; unknown IDs are a no-op."""
        if client_id in self.clients:
            del self.clients[client_id]
            self._forget_authlib_client(client_id)
            log.info(f"Removed OAuth client {client_id}")
        return True

    def get_client(self, client_id):
        """Return the authlib client registered for *client_id*, or None."""
        client = self.clients.get(client_id)
        return client["client"] if client else None

    def get_client_info(self, client_id):
        """Return the stored ``OAuthClientInformationFull`` for *client_id*, or None."""
        client = self.clients.get(client_id)
        return client["client_info"] if client else None

    def get_server_metadata_url(self, client_id):
        """Return the metadata-discovery URL used for *client_id*, or None.

        ``self.clients`` stores a dict per client, so the authlib client is
        unwrapped first. authlib keeps the ``server_metadata_url`` passed to
        ``register()`` as the private ``_server_metadata_url`` attribute, so both
        spellings are checked. ``client_info.issuer`` (the value passed at
        registration) is the final fallback.
        """
        entry = self.clients.get(client_id)
        if not entry:
            return None

        client = entry.get("client")
        for attr in ("server_metadata_url", "_server_metadata_url"):
            url = getattr(client, attr, None)
            if url:
                return url

        return getattr(entry.get("client_info"), "issuer", None) or None

    async def get_oauth_token(
        self, user_id: str, client_id: str, force_refresh: bool = False
    ):
        """
        Get a valid OAuth token for the user, automatically refreshing if needed.

        A refresh is attempted when ``force_refresh`` is set or the stored token
        expires within five minutes. If the refresh fails, the session is deleted.

        Args:
            user_id: The user ID
            client_id: The OAuth client ID (stored as the session's provider)
            force_refresh: Force token refresh even if current token appears valid

        Returns:
            dict: OAuth token data with access_token, or None if no valid token available
        """
        try:
            # Get the OAuth session
            session = OAuthSessions.get_session_by_provider_and_user_id(
                client_id, user_id
            )
            if not session:
                log.warning(
                    f"No OAuth session found for user {user_id}, client_id {client_id}"
                )
                return None

            if force_refresh or datetime.now() + timedelta(
                minutes=5
            ) >= datetime.fromtimestamp(session.expires_at):
                log.debug(
                    f"Token refresh needed for user {user_id}, client_id {session.provider}"
                )
                refreshed_token = await self._refresh_token(session)
                if refreshed_token:
                    return refreshed_token
                else:
                    log.warning(
                        f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
                    )
                    OAuthSessions.delete_session_by_id(session.id)
                    return None
            return session.token

        except Exception as e:
            log.error(f"Error getting OAuth token for user {user_id}: {e}")
            return None

    async def _refresh_token(self, session) -> dict:
        """
        Refresh the session's token and persist the new token data.

        No locking is done here; concurrent callers may refresh the same session.

        Args:
            session: The OAuth session object

        Returns:
            dict: Refreshed token data as stored, or None if the refresh or the
            update of the stored session failed
        """
        # Keep the id separately: the updated session returned by the store may
        # be None, and the error paths still need to log which session failed.
        session_id = session.id
        try:
            refreshed_token = await self._perform_token_refresh(session)
            if not refreshed_token:
                log.error(f"Failed to refresh token for session {session_id}")
                return None

            updated_session = OAuthSessions.update_session_by_id(
                session_id, refreshed_token
            )
            if not updated_session:
                log.error(f"Failed to store refreshed token for session {session_id}")
                return None

            log.info(f"Successfully refreshed token for session {session_id}")
            return updated_session.token

        except Exception as e:
            log.error(f"Error refreshing token for session {session_id}: {e}")
            return None

    async def _perform_token_refresh(self, session) -> dict:
        """
        Exchange the session's refresh token for a new token.

        The token endpoint is read from the client's metadata document (see
        ``get_server_metadata_url``). The refresh request sends the client
        credentials in the form body.

        Args:
            session: The OAuth session object

        Returns:
            dict: New token data (keeping the previous refresh_token if the server
            did not return one), or None if refresh failed
        """
        client_id = session.provider
        token_data = session.token

        if not token_data.get("refresh_token"):
            log.warning(f"No refresh token available for session {session.id}")
            return None

        try:
            client = self.get_client(client_id)
            if not client:
                log.error(f"No OAuth client found for provider {client_id}")
                return None

            server_metadata_url = self.get_server_metadata_url(client_id)
            if not server_metadata_url:
                log.error(f"No server metadata URL found for client_id {client_id}")
                return None

            async with aiohttp.ClientSession(trust_env=True) as session_http:
                token_endpoint = None
                async with session_http.get(
                    server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as r:
                    if r.status == 200:
                        openid_data = await r.json()
                        token_endpoint = openid_data.get("token_endpoint")
                    else:
                        log.error(
                            f"Failed to fetch OpenID configuration for client_id {client_id}"
                        )
                if not token_endpoint:
                    log.error(f"No token endpoint found for client_id {client_id}")
                    return None

                # Prepare refresh request
                refresh_data = {
                    "grant_type": "refresh_token",
                    "refresh_token": token_data["refresh_token"],
                    "client_id": client.client_id,
                }
                if getattr(client, "client_secret", None):
                    refresh_data["client_secret"] = client.client_secret

                async with session_http.post(
                    token_endpoint,
                    data=refresh_data,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        error_text = await r.text()
                        log.error(
                            f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
                        )
                        return None
                    new_token_data = await r.json()

            # Preserve the refresh_token if the server did not issue a new one
            new_token_data.setdefault("refresh_token", token_data["refresh_token"])

            now = datetime.now().timestamp()
            new_token_data["issued_at"] = now
            if "expires_in" in new_token_data and "expires_at" not in new_token_data:
                # float() accepts both numeric and numeric-string expires_in values.
                new_token_data["expires_at"] = int(
                    now + float(new_token_data["expires_in"])
                )

            log.debug(f"Token refresh successful for client_id {client_id}")
            return new_token_data

        except Exception as e:
            log.error(f"Exception during token refresh for client_id {client_id}: {e}")
            return None

    async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
        """Redirect the browser to the authorization endpoint of *client_id*.

        Raises:
            HTTPException: 404 if *client_id* is not registered.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        client_info = self.get_client_info(client_id)
        if client_info is None:
            raise HTTPException(404)

        redirect_uri = (
            client_info.redirect_uris[0] if client_info.redirect_uris else None
        )
        return await client.authorize_redirect(request, str(redirect_uri))

    async def handle_callback(self, request, client_id: str, user_id: str, response):
        """Finish the authorization-code flow for *client_id* and store the token.

        Exchanges the callback for a token and replaces any existing OAuth
        sessions of *user_id* for this client with a new one. It then redirects
        to WEBUI_URL (or the request base URL); on failure the redirect carries
        an ``error`` query parameter. Headers of *response* are copied onto the
        redirect.

        Raises:
            HTTPException: 404 if *client_id* is not registered.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        error_message = None
        try:
            token = await client.authorize_access_token(request)
            if token:
                try:
                    # Add timestamp for tracking
                    token["issued_at"] = datetime.now().timestamp()

                    if "expires_in" in token and "expires_at" not in token:
                        # Integer timestamp, same as in _perform_token_refresh.
                        token["expires_at"] = int(
                            datetime.now().timestamp() + float(token["expires_in"])
                        )

                    # Clean up any existing sessions for this user/client_id first
                    for existing_session in OAuthSessions.get_sessions_by_user_id(
                        user_id
                    ):
                        if existing_session.provider == client_id:
                            OAuthSessions.delete_session_by_id(existing_session.id)

                    OAuthSessions.create_session(
                        user_id=user_id,
                        provider=client_id,
                        token=token,
                    )
                    log.info(
                        f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
                    )
                except Exception as e:
                    error_message = "Failed to store OAuth session server-side"
                    log.error(f"Failed to store OAuth session server-side: {e}")
            else:
                error_message = "Failed to obtain OAuth token"
                log.warning(error_message)
        except Exception as e:
            error_message = "OAuth callback error"
            log.warning(f"OAuth callback error: {e}")

        redirect_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        if error_message:
            log.debug(error_message)
            redirect_url = f"{redirect_url}/?error={error_message}"

        return RedirectResponse(url=redirect_url, headers=response.headers)
 | |
| 
 | |
| 
 | |
| class OAuthManager:
 | |
|     def __init__(self, app):
 | |
|         self.oauth = OAuth()
 | |
|         self.app = app
 | |
| 
 | |
|         self._clients = {}
 | |
|         for _, provider_config in OAUTH_PROVIDERS.items():
 | |
|             provider_config["register"](self.oauth)
 | |
| 
 | |
|     def get_client(self, provider_name):
 | |
|         if provider_name not in self._clients:
 | |
|             self._clients[provider_name] = self.oauth.create_client(provider_name)
 | |
|         return self._clients[provider_name]
 | |
| 
 | |
|     def get_server_metadata_url(self, provider_name):
 | |
|         if provider_name in self._clients:
 | |
|             client = self._clients[provider_name]
 | |
|             return (
 | |
|                 client.server_metadata_url
 | |
|                 if hasattr(client, "server_metadata_url")
 | |
|                 else None
 | |
|             )
 | |
|         return None
 | |
| 
 | |
|     async def get_oauth_token(
 | |
|         self, user_id: str, session_id: str, force_refresh: bool = False
 | |
|     ):
 | |
|         """
 | |
|         Get a valid OAuth token for the user, automatically refreshing if needed.
 | |
| 
 | |
|         Args:
 | |
|             user_id: The user ID
 | |
|             provider: Optional provider name. If None, gets the most recent session.
 | |
|             force_refresh: Force token refresh even if current token appears valid
 | |
| 
 | |
|         Returns:
 | |
|             dict: OAuth token data with access_token, or None if no valid token available
 | |
|         """
 | |
|         try:
 | |
|             # Get the OAuth session
 | |
|             session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
 | |
|             if not session:
 | |
|                 log.warning(
 | |
|                     f"No OAuth session found for user {user_id}, session {session_id}"
 | |
|                 )
 | |
|                 return None
 | |
| 
 | |
|             if force_refresh or datetime.now() + timedelta(
 | |
|                 minutes=5
 | |
|             ) >= datetime.fromtimestamp(session.expires_at):
 | |
|                 log.debug(
 | |
|                     f"Token refresh needed for user {user_id}, provider {session.provider}"
 | |
|                 )
 | |
|                 refreshed_token = await self._refresh_token(session)
 | |
|                 if refreshed_token:
 | |
|                     return refreshed_token
 | |
|                 else:
 | |
|                     log.warning(
 | |
|                         f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
 | |
|                     )
 | |
|                     OAuthSessions.delete_session_by_id(session.id)
 | |
| 
 | |
|                     return None
 | |
|             return session.token
 | |
| 
 | |
|         except Exception as e:
 | |
|             log.error(f"Error getting OAuth token for user {user_id}: {e}")
 | |
|             return None
 | |
| 
 | |
|     async def _refresh_token(self, session) -> dict:
 | |
|         """
 | |
|         Refresh an OAuth token if needed, with concurrency protection.
 | |
| 
 | |
|         Args:
 | |
|             session: The OAuth session object
 | |
| 
 | |
|         Returns:
 | |
|             dict: Refreshed token data, or None if refresh failed
 | |
|         """
 | |
|         try:
 | |
|             # Perform the actual refresh
 | |
|             refreshed_token = await self._perform_token_refresh(session)
 | |
| 
 | |
|             if refreshed_token:
 | |
|                 # Update the session with new token data
 | |
|                 session = OAuthSessions.update_session_by_id(
 | |
|                     session.id, refreshed_token
 | |
|                 )
 | |
|                 log.info(f"Successfully refreshed token for session {session.id}")
 | |
|                 return session.token
 | |
|             else:
 | |
|                 log.error(f"Failed to refresh token for session {session.id}")
 | |
|                 return None
 | |
| 
 | |
|         except Exception as e:
 | |
|             log.error(f"Error refreshing token for session {session.id}: {e}")
 | |
|             return None
 | |
| 
 | |
|     async def _perform_token_refresh(self, session) -> dict:
 | |
|         """
 | |
|         Perform the actual OAuth token refresh.
 | |
| 
 | |
|         Args:
 | |
|             session: The OAuth session object
 | |
| 
 | |
|         Returns:
 | |
|             dict: New token data, or None if refresh failed
 | |
|         """
 | |
|         provider = session.provider
 | |
|         token_data = session.token
 | |
| 
 | |
|         if not token_data.get("refresh_token"):
 | |
|             log.warning(f"No refresh token available for session {session.id}")
 | |
|             return None
 | |
| 
 | |
|         try:
 | |
|             client = self.get_client(provider)
 | |
|             if not client:
 | |
|                 log.error(f"No OAuth client found for provider {provider}")
 | |
|                 return None
 | |
| 
 | |
|             server_metadata_url = self.get_server_metadata_url(provider)
 | |
|             token_endpoint = None
 | |
|             async with aiohttp.ClientSession(trust_env=True) as session_http:
 | |
|                 async with session_http.get(server_metadata_url) as r:
 | |
|                     if r.status == 200:
 | |
|                         openid_data = await r.json()
 | |
|                         token_endpoint = openid_data.get("token_endpoint")
 | |
|                     else:
 | |
|                         log.error(
 | |
|                             f"Failed to fetch OpenID configuration for provider {provider}"
 | |
|                         )
 | |
|             if not token_endpoint:
 | |
|                 log.error(f"No token endpoint found for provider {provider}")
 | |
|                 return None
 | |
| 
 | |
|             # Prepare refresh request
 | |
|             refresh_data = {
 | |
|                 "grant_type": "refresh_token",
 | |
|                 "refresh_token": token_data["refresh_token"],
 | |
|                 "client_id": client.client_id,
 | |
|             }
 | |
|             # Add client_secret if available (some providers require it)
 | |
|             if hasattr(client, "client_secret") and client.client_secret:
 | |
|                 refresh_data["client_secret"] = client.client_secret
 | |
| 
 | |
|             # Make refresh request
 | |
|             async with aiohttp.ClientSession(trust_env=True) as session_http:
 | |
|                 async with session_http.post(
 | |
|                     token_endpoint,
 | |
|                     data=refresh_data,
 | |
|                     headers={"Content-Type": "application/x-www-form-urlencoded"},
 | |
|                     ssl=AIOHTTP_CLIENT_SESSION_SSL,
 | |
|                 ) as r:
 | |
|                     if r.status == 200:
 | |
|                         new_token_data = await r.json()
 | |
| 
 | |
|                         # Merge with existing token data (preserve refresh_token if not provided)
 | |
|                         if "refresh_token" not in new_token_data:
 | |
|                             new_token_data["refresh_token"] = token_data[
 | |
|                                 "refresh_token"
 | |
|                             ]
 | |
| 
 | |
|                         # Add timestamp for tracking
 | |
|                         new_token_data["issued_at"] = datetime.now().timestamp()
 | |
| 
 | |
|                         # Calculate expires_at if we have expires_in
 | |
|                         if (
 | |
|                             "expires_in" in new_token_data
 | |
|                             and "expires_at" not in new_token_data
 | |
|                         ):
 | |
|                             new_token_data["expires_at"] = int(
 | |
|                                 datetime.now().timestamp()
 | |
|                                 + new_token_data["expires_in"]
 | |
|                             )
 | |
| 
 | |
|                         log.debug(f"Token refresh successful for provider {provider}")
 | |
|                         return new_token_data
 | |
|                     else:
 | |
|                         error_text = await r.text()
 | |
|                         log.error(
 | |
|                             f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
 | |
|                         )
 | |
|                         return None
 | |
| 
 | |
|         except Exception as e:
 | |
|             log.error(f"Exception during token refresh for provider {provider}: {e}")
 | |
|             return None
 | |
| 
 | |
|     def get_user_role(self, user, user_data):
 | |
|         user_count = Users.get_num_users()
 | |
|         if user and user_count == 1:
 | |
|             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
 | |
|             log.debug("Assigning the only user the admin role")
 | |
|             return "admin"
 | |
|         if not user and user_count == 0:
 | |
|             # If there are no users, assign the role "admin", as the first user will be an admin
 | |
|             log.debug("Assigning the first user the admin role")
 | |
|             return "admin"
 | |
| 
 | |
|         if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
 | |
|             log.debug("Running OAUTH Role management")
 | |
|             oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
 | |
|             oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
 | |
|             oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
 | |
|             oauth_roles = []
 | |
|             # Default/fallback role if no matching roles are found
 | |
|             role = auth_manager_config.DEFAULT_USER_ROLE
 | |
| 
 | |
|             # Next block extracts the roles from the user data, accepting nested claims of any depth
 | |
|             if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
 | |
|                 claim_data = user_data
 | |
|                 nested_claims = oauth_claim.split(".")
 | |
|                 for nested_claim in nested_claims:
 | |
|                     claim_data = claim_data.get(nested_claim, {})
 | |
| 
 | |
|                 oauth_roles = []
 | |
| 
 | |
|                 if isinstance(claim_data, list):
 | |
|                     oauth_roles = claim_data
 | |
|                 if isinstance(claim_data, str) or isinstance(claim_data, int):
 | |
|                     oauth_roles = [str(claim_data)]
 | |
| 
 | |
|             log.debug(f"Oauth Roles claim: {oauth_claim}")
 | |
|             log.debug(f"User roles from oauth: {oauth_roles}")
 | |
|             log.debug(f"Accepted user roles: {oauth_allowed_roles}")
 | |
|             log.debug(f"Accepted admin roles: {oauth_admin_roles}")
 | |
| 
 | |
|             # If any roles are found, check if they match the allowed or admin roles
 | |
|             if oauth_roles:
 | |
|                 # If role management is enabled, and matching roles are provided, use the roles
 | |
|                 for allowed_role in oauth_allowed_roles:
 | |
|                     # If the user has any of the allowed roles, assign the role "user"
 | |
|                     if allowed_role in oauth_roles:
 | |
|                         log.debug("Assigned user the user role")
 | |
|                         role = "user"
 | |
|                         break
 | |
|                 for admin_role in oauth_admin_roles:
 | |
|                     # If the user has any of the admin roles, assign the role "admin"
 | |
|                     if admin_role in oauth_roles:
 | |
|                         log.debug("Assigned user the admin role")
 | |
|                         role = "admin"
 | |
|                         break
 | |
|         else:
 | |
|             if not user:
 | |
|                 # If role management is disabled, use the default role for new users
 | |
|                 role = auth_manager_config.DEFAULT_USER_ROLE
 | |
|             else:
 | |
|                 # If role management is disabled, use the existing role for existing users
 | |
|                 role = user.role
 | |
| 
 | |
|         return role
 | |
| 
 | |
|     def update_user_groups(self, user, user_data, default_permissions):
 | |
|         log.debug("Running OAUTH Group management")
 | |
|         oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
 | |
| 
 | |
|         try:
 | |
|             blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
 | |
|         except Exception as e:
 | |
|             log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
 | |
|             blocked_groups = []
 | |
| 
 | |
|         user_oauth_groups = []
 | |
|         # Nested claim search for groups claim
 | |
|         if oauth_claim:
 | |
|             claim_data = user_data
 | |
|             nested_claims = oauth_claim.split(".")
 | |
|             for nested_claim in nested_claims:
 | |
|                 claim_data = claim_data.get(nested_claim, {})
 | |
| 
 | |
|             if isinstance(claim_data, list):
 | |
|                 user_oauth_groups = claim_data
 | |
|             elif isinstance(claim_data, str):
 | |
|                 user_oauth_groups = [claim_data]
 | |
|             else:
 | |
|                 user_oauth_groups = []
 | |
| 
 | |
|         user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
 | |
|         all_available_groups: list[GroupModel] = Groups.get_groups()
 | |
| 
 | |
|         # Create groups if they don't exist and creation is enabled
 | |
|         if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
 | |
|             log.debug("Checking for missing groups to create...")
 | |
|             all_group_names = {g.name for g in all_available_groups}
 | |
|             groups_created = False
 | |
|             # Determine creator ID: Prefer admin, fallback to current user if no admin exists
 | |
|             admin_user = Users.get_super_admin_user()
 | |
|             creator_id = admin_user.id if admin_user else user.id
 | |
|             log.debug(f"Using creator ID {creator_id} for potential group creation.")
 | |
| 
 | |
|             for group_name in user_oauth_groups:
 | |
|                 if group_name not in all_group_names:
 | |
|                     log.info(
 | |
|                         f"Group '{group_name}' not found via OAuth claim. Creating group..."
 | |
|                     )
 | |
|                     try:
 | |
|                         new_group_form = GroupForm(
 | |
|                             name=group_name,
 | |
|                             description=f"Group '{group_name}' created automatically via OAuth.",
 | |
|                             permissions=default_permissions,  # Use default permissions from function args
 | |
|                             user_ids=[],  # Start with no users, user will be added later by subsequent logic
 | |
|                         )
 | |
|                         # Use determined creator ID (admin or fallback to current user)
 | |
|                         created_group = Groups.insert_new_group(
 | |
|                             creator_id, new_group_form
 | |
|                         )
 | |
|                         if created_group:
 | |
|                             log.info(
 | |
|                                 f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
 | |
|                             )
 | |
|                             groups_created = True
 | |
|                             # Add to local set to prevent duplicate creation attempts in this run
 | |
|                             all_group_names.add(group_name)
 | |
|                         else:
 | |
|                             log.error(
 | |
|                                 f"Failed to create group '{group_name}' via OAuth."
 | |
|                             )
 | |
|                     except Exception as e:
 | |
|                         log.error(f"Error creating group '{group_name}' via OAuth: {e}")
 | |
| 
 | |
|             # Refresh the list of all available groups if any were created
 | |
|             if groups_created:
 | |
|                 all_available_groups = Groups.get_groups()
 | |
|                 log.debug("Refreshed list of all available groups after creation.")
 | |
| 
 | |
|         log.debug(f"Oauth Groups claim: {oauth_claim}")
 | |
|         log.debug(f"User oauth groups: {user_oauth_groups}")
 | |
|         log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
 | |
|         log.debug(
 | |
|             f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
 | |
|         )
 | |
| 
 | |
|         # Remove groups that user is no longer a part of
 | |
|         for group_model in user_current_groups:
 | |
|             if (
 | |
|                 user_oauth_groups
 | |
|                 and group_model.name not in user_oauth_groups
 | |
|                 and not is_in_blocked_groups(group_model.name, blocked_groups)
 | |
|             ):
 | |
|                 # Remove group from user
 | |
|                 log.debug(
 | |
|                     f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
 | |
|                 )
 | |
| 
 | |
|                 user_ids = group_model.user_ids
 | |
|                 user_ids = [i for i in user_ids if i != user.id]
 | |
| 
 | |
|                 # In case a group is created, but perms are never assigned to the group by hitting "save"
 | |
|                 group_permissions = group_model.permissions
 | |
|                 if not group_permissions:
 | |
|                     group_permissions = default_permissions
 | |
| 
 | |
|                 update_form = GroupUpdateForm(
 | |
|                     name=group_model.name,
 | |
|                     description=group_model.description,
 | |
|                     permissions=group_permissions,
 | |
|                     user_ids=user_ids,
 | |
|                 )
 | |
|                 Groups.update_group_by_id(
 | |
|                     id=group_model.id, form_data=update_form, overwrite=False
 | |
|                 )
 | |
| 
 | |
|         # Add user to new groups
 | |
|         for group_model in all_available_groups:
 | |
|             if (
 | |
|                 user_oauth_groups
 | |
|                 and group_model.name in user_oauth_groups
 | |
|                 and not any(gm.name == group_model.name for gm in user_current_groups)
 | |
|                 and not is_in_blocked_groups(group_model.name, blocked_groups)
 | |
|             ):
 | |
|                 # Add user to group
 | |
|                 log.debug(
 | |
|                     f"Adding user to group {group_model.name} as it was found in their oauth groups"
 | |
|                 )
 | |
| 
 | |
|                 user_ids = group_model.user_ids
 | |
|                 user_ids.append(user.id)
 | |
| 
 | |
|                 # In case a group is created, but perms are never assigned to the group by hitting "save"
 | |
|                 group_permissions = group_model.permissions
 | |
|                 if not group_permissions:
 | |
|                     group_permissions = default_permissions
 | |
| 
 | |
|                 update_form = GroupUpdateForm(
 | |
|                     name=group_model.name,
 | |
|                     description=group_model.description,
 | |
|                     permissions=group_permissions,
 | |
|                     user_ids=user_ids,
 | |
|                 )
 | |
|                 Groups.update_group_by_id(
 | |
|                     id=group_model.id, form_data=update_form, overwrite=False
 | |
|                 )
 | |
| 
 | |
|     async def _process_picture_url(
 | |
|         self, picture_url: str, access_token: str = None
 | |
|     ) -> str:
 | |
|         """Process a picture URL and return a base64 encoded data URL.
 | |
| 
 | |
|         Args:
 | |
|             picture_url: The URL of the picture to process
 | |
|             access_token: Optional OAuth access token for authenticated requests
 | |
| 
 | |
|         Returns:
 | |
|             A data URL containing the base64 encoded picture, or "/user.png" if processing fails
 | |
|         """
 | |
|         if not picture_url:
 | |
|             return "/user.png"
 | |
| 
 | |
|         try:
 | |
|             get_kwargs = {}
 | |
|             if access_token:
 | |
|                 get_kwargs["headers"] = {
 | |
|                     "Authorization": f"Bearer {access_token}",
 | |
|                 }
 | |
|             async with aiohttp.ClientSession(trust_env=True) as session:
 | |
|                 async with session.get(
 | |
|                     picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
 | |
|                 ) as resp:
 | |
|                     if resp.ok:
 | |
|                         picture = await resp.read()
 | |
|                         base64_encoded_picture = base64.b64encode(picture).decode(
 | |
|                             "utf-8"
 | |
|                         )
 | |
|                         guessed_mime_type = mimetypes.guess_type(picture_url)[0]
 | |
|                         if guessed_mime_type is None:
 | |
|                             guessed_mime_type = "image/jpeg"
 | |
|                         return (
 | |
|                             f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
 | |
|                         )
 | |
|                     else:
 | |
|                         log.warning(
 | |
|                             f"Failed to fetch profile picture from {picture_url}"
 | |
|                         )
 | |
|                         return "/user.png"
 | |
|         except Exception as e:
 | |
|             log.error(f"Error processing profile picture '{picture_url}': {e}")
 | |
|             return "/user.png"
 | |
| 
 | |
|     async def handle_login(self, request, provider):
 | |
|         if provider not in OAUTH_PROVIDERS:
 | |
|             raise HTTPException(404)
 | |
|         # If the provider has a custom redirect URL, use that, otherwise automatically generate one
 | |
|         redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
 | |
|             "oauth_login_callback", provider=provider
 | |
|         )
 | |
|         client = self.get_client(provider)
 | |
|         if client is None:
 | |
|             raise HTTPException(404)
 | |
|         return await client.authorize_redirect(request, redirect_uri)
 | |
| 
 | |
|     async def handle_callback(self, request, provider, response):
 | |
|         if provider not in OAUTH_PROVIDERS:
 | |
|             raise HTTPException(404)
 | |
| 
 | |
|         error_message = None
 | |
|         try:
 | |
|             client = self.get_client(provider)
 | |
|             try:
 | |
|                 token = await client.authorize_access_token(request)
 | |
|             except Exception as e:
 | |
|                 log.warning(f"OAuth callback error: {e}")
 | |
|                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
| 
 | |
|             # Try to get userinfo from the token first, some providers include it there
 | |
|             user_data: UserInfo = token.get("userinfo")
 | |
|             if (
 | |
|                 (not user_data)
 | |
|                 or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
 | |
|                 or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
 | |
|             ):
 | |
|                 user_data: UserInfo = await client.userinfo(token=token)
 | |
|             if (
 | |
|                 provider == "feishu"
 | |
|                 and isinstance(user_data, dict)
 | |
|                 and "data" in user_data
 | |
|             ):
 | |
|                 user_data = user_data["data"]
 | |
|             if not user_data:
 | |
|                 log.warning(f"OAuth callback failed, user data is missing: {token}")
 | |
|                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
| 
 | |
|             # Extract the "sub" claim, using custom claim if configured
 | |
|             if auth_manager_config.OAUTH_SUB_CLAIM:
 | |
|                 sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
 | |
|             else:
 | |
|                 # Fallback to the default sub claim if not configured
 | |
|                 sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
 | |
|             if not sub:
 | |
|                 log.warning(f"OAuth callback failed, sub is missing: {user_data}")
 | |
|                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
| 
 | |
|             provider_sub = f"{provider}@{sub}"
 | |
| 
 | |
|             # Email extraction
 | |
|             email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
 | |
|             email = user_data.get(email_claim, "")
 | |
|             # We currently mandate that email addresses are provided
 | |
|             if not email:
 | |
|                 # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
 | |
|                 if provider == "github":
 | |
|                     try:
 | |
|                         access_token = token.get("access_token")
 | |
|                         headers = {"Authorization": f"Bearer {access_token}"}
 | |
|                         async with aiohttp.ClientSession(trust_env=True) as session:
 | |
|                             async with session.get(
 | |
|                                 "https://api.github.com/user/emails",
 | |
|                                 headers=headers,
 | |
|                                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
 | |
|                             ) as resp:
 | |
|                                 if resp.ok:
 | |
|                                     emails = await resp.json()
 | |
|                                     # use the primary email as the user's email
 | |
|                                     primary_email = next(
 | |
|                                         (
 | |
|                                             e["email"]
 | |
|                                             for e in emails
 | |
|                                             if e.get("primary")
 | |
|                                         ),
 | |
|                                         None,
 | |
|                                     )
 | |
|                                     if primary_email:
 | |
|                                         email = primary_email
 | |
|                                     else:
 | |
|                                         log.warning(
 | |
|                                             "No primary email found in GitHub response"
 | |
|                                         )
 | |
|                                         raise HTTPException(
 | |
|                                             400, detail=ERROR_MESSAGES.INVALID_CRED
 | |
|                                         )
 | |
|                                 else:
 | |
|                                     log.warning("Failed to fetch GitHub email")
 | |
|                                     raise HTTPException(
 | |
|                                         400, detail=ERROR_MESSAGES.INVALID_CRED
 | |
|                                     )
 | |
|                     except Exception as e:
 | |
|                         log.warning(f"Error fetching GitHub email: {e}")
 | |
|                         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
|                 else:
 | |
|                     log.warning(f"OAuth callback failed, email is missing: {user_data}")
 | |
|                     raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
|             email = email.lower()
 | |
| 
 | |
|             # If allowed domains are configured, check if the email domain is in the list
 | |
|             if (
 | |
|                 "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
 | |
|                 and email.split("@")[-1]
 | |
|                 not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
 | |
|             ):
 | |
|                 log.warning(
 | |
|                     f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
 | |
|                 )
 | |
|                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 | |
| 
 | |
|             # Check if the user exists
 | |
|             user = Users.get_user_by_oauth_sub(provider_sub)
 | |
|             if not user:
 | |
|                 # If the user does not exist, check if merging is enabled
 | |
|                 if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
 | |
|                     # Check if the user exists by email
 | |
|                     user = Users.get_user_by_email(email)
 | |
|                     if user:
 | |
|                         # Update the user with the new oauth sub
 | |
|                         Users.update_user_oauth_sub_by_id(user.id, provider_sub)
 | |
| 
 | |
|             if user:
 | |
|                 determined_role = self.get_user_role(user, user_data)
 | |
|                 if user.role != determined_role:
 | |
|                     Users.update_user_role_by_id(user.id, determined_role)
 | |
|                 # Update profile picture if enabled and different from current
 | |
|                 if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
 | |
|                     picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
 | |
|                     if picture_claim:
 | |
|                         new_picture_url = user_data.get(
 | |
|                             picture_claim,
 | |
|                             OAUTH_PROVIDERS[provider].get("picture_url", ""),
 | |
|                         )
 | |
|                         processed_picture_url = await self._process_picture_url(
 | |
|                             new_picture_url, token.get("access_token")
 | |
|                         )
 | |
|                         if processed_picture_url != user.profile_image_url:
 | |
|                             Users.update_user_profile_image_url_by_id(
 | |
|                                 user.id, processed_picture_url
 | |
|                             )
 | |
|                             log.debug(f"Updated profile picture for user {user.email}")
 | |
|             else:
 | |
|                 # If the user does not exist, check if signups are enabled
 | |
|                 if auth_manager_config.ENABLE_OAUTH_SIGNUP:
 | |
|                     # Check if an existing user with the same email already exists
 | |
|                     existing_user = Users.get_user_by_email(email)
 | |
|                     if existing_user:
 | |
|                         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 | |
| 
 | |
|                     picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
 | |
|                     if picture_claim:
 | |
|                         picture_url = user_data.get(
 | |
|                             picture_claim,
 | |
|                             OAUTH_PROVIDERS[provider].get("picture_url", ""),
 | |
|                         )
 | |
|                         picture_url = await self._process_picture_url(
 | |
|                             picture_url, token.get("access_token")
 | |
|                         )
 | |
|                     else:
 | |
|                         picture_url = "/user.png"
 | |
|                     username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
 | |
| 
 | |
|                     name = user_data.get(username_claim)
 | |
|                     if not name:
 | |
|                         log.warning("Username claim is missing, using email as name")
 | |
|                         name = email
 | |
| 
 | |
|                     user = Auths.insert_new_auth(
 | |
|                         email=email,
 | |
|                         password=get_password_hash(
 | |
|                             str(uuid.uuid4())
 | |
|                         ),  # Random password, not used
 | |
|                         name=name,
 | |
|                         profile_image_url=picture_url,
 | |
|                         role=self.get_user_role(None, user_data),
 | |
|                         oauth_sub=provider_sub,
 | |
|                     )
 | |
| 
 | |
|                     if auth_manager_config.WEBHOOK_URL:
 | |
|                         await post_webhook(
 | |
|                             WEBUI_NAME,
 | |
|                             auth_manager_config.WEBHOOK_URL,
 | |
|                             WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
 | |
|                             {
 | |
|                                 "action": "signup",
 | |
|                                 "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
 | |
|                                 "user": user.model_dump_json(exclude_none=True),
 | |
|                             },
 | |
|                         )
 | |
|                 else:
 | |
|                     raise HTTPException(
 | |
|                         status.HTTP_403_FORBIDDEN,
 | |
|                         detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
 | |
|                     )
 | |
| 
 | |
|             jwt_token = create_token(
 | |
|                 data={"id": user.id},
 | |
|                 expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
 | |
|             )
 | |
|             if (
 | |
|                 auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
 | |
|                 and user.role != "admin"
 | |
|             ):
 | |
|                 self.update_user_groups(
 | |
|                     user=user,
 | |
|                     user_data=user_data,
 | |
|                     default_permissions=request.app.state.config.USER_PERMISSIONS,
 | |
|                 )
 | |
| 
 | |
|         except Exception as e:
 | |
|             log.error(f"Error during OAuth process: {e}")
 | |
|             error_message = (
 | |
|                 e.detail
 | |
|                 if isinstance(e, HTTPException) and e.detail
 | |
|                 else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
 | |
|             )
 | |
| 
 | |
|         redirect_base_url = (
 | |
|             str(request.app.state.config.WEBUI_URL or request.base_url)
 | |
|         ).rstrip("/")
 | |
|         redirect_url = f"{redirect_base_url}/auth"
 | |
| 
 | |
|         if error_message:
 | |
|             redirect_url = f"{redirect_url}?error={error_message}"
 | |
|             return RedirectResponse(url=redirect_url, headers=response.headers)
 | |
| 
 | |
|         response = RedirectResponse(url=redirect_url, headers=response.headers)
 | |
| 
 | |
|         # Set the cookie token
 | |
|         # Redirect back to the frontend with the JWT token
 | |
|         response.set_cookie(
 | |
|             key="token",
 | |
|             value=jwt_token,
 | |
|             httponly=False,  # Required for frontend access
 | |
|             samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
 | |
|             secure=WEBUI_AUTH_COOKIE_SECURE,
 | |
|         )
 | |
| 
 | |
|         # Legacy cookies for compatibility with older frontend versions
 | |
|         if ENABLE_OAUTH_ID_TOKEN_COOKIE:
 | |
|             response.set_cookie(
 | |
|                 key="oauth_id_token",
 | |
|                 value=token.get("id_token"),
 | |
|                 httponly=True,
 | |
|                 samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
 | |
|                 secure=WEBUI_AUTH_COOKIE_SECURE,
 | |
|             )
 | |
| 
 | |
|         try:
 | |
|             # Add timestamp for tracking
 | |
|             token["issued_at"] = datetime.now().timestamp()
 | |
| 
 | |
|             # Calculate expires_at if we have expires_in
 | |
|             if "expires_in" in token and "expires_at" not in token:
 | |
|                 token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
 | |
| 
 | |
|             # Clean up any existing sessions for this user/provider first
 | |
|             sessions = OAuthSessions.get_sessions_by_user_id(user.id)
 | |
|             for session in sessions:
 | |
|                 if session.provider == provider:
 | |
|                     OAuthSessions.delete_session_by_id(session.id)
 | |
| 
 | |
|             session = OAuthSessions.create_session(
 | |
|                 user_id=user.id,
 | |
|                 provider=provider,
 | |
|                 token=token,
 | |
|             )
 | |
| 
 | |
|             response.set_cookie(
 | |
|                 key="oauth_session_id",
 | |
|                 value=session.id,
 | |
|                 httponly=True,
 | |
|                 samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
 | |
|                 secure=WEBUI_AUTH_COOKIE_SECURE,
 | |
|             )
 | |
| 
 | |
|             log.info(
 | |
|                 f"Stored OAuth session server-side for user {user.id}, provider {provider}"
 | |
|             )
 | |
|         except Exception as e:
 | |
|             log.error(f"Failed to store OAuth session server-side: {e}")
 | |
| 
 | |
|         return response
 |