open-webui/backend/open_webui/utils/oauth.py

1319 lines
53 KiB
Python
Raw Normal View History

import base64
import hashlib
2024-10-16 22:32:57 +08:00
import logging
import mimetypes
import sys
import urllib
import uuid
2025-05-02 18:47:02 +08:00
import json
from datetime import datetime, timedelta
2025-09-03 22:50:02 +08:00
import re
import fnmatch
import time
import secrets
from cryptography.fernet import Fernet
2025-09-03 22:50:02 +08:00
import aiohttp
2024-10-16 22:32:57 +08:00
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
status,
)
2024-10-16 22:32:57 +08:00
from starlette.responses import RedirectResponse
from typing import Optional
2024-12-10 16:54:13 +08:00
from open_webui.models.auths import Auths
from open_webui.models.oauth_sessions import OAuthSessions
2024-12-10 16:54:13 +08:00
from open_webui.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
ENABLE_OAUTH_GROUP_MANAGEMENT,
ENABLE_OAUTH_GROUP_CREATION,
2025-05-02 18:47:02 +08:00
OAUTH_BLOCKED_GROUPS,
OAUTH_ROLES_CLAIM,
2025-08-09 04:46:14 +08:00
OAUTH_SUB_CLAIM,
OAUTH_GROUPS_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
2024-10-21 09:38:06 +08:00
OAUTH_ADMIN_ROLES,
OAUTH_ALLOWED_DOMAINS,
2025-05-07 00:00:35 +08:00
OAUTH_UPDATE_PICTURE_ON_LOGIN,
2024-10-21 09:38:06 +08:00
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
2025-02-16 16:11:18 +08:00
from open_webui.env import (
2025-05-15 03:33:52 +08:00
AIOHTTP_CLIENT_SESSION_SSL,
2025-02-16 16:11:18 +08:00
WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
ENABLE_OAUTH_ID_TOKEN_COOKIE,
OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
2025-02-16 16:11:18 +08:00
)
from open_webui.utils.misc import parse_duration
2024-12-09 08:01:56 +08:00
from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
from mcp.shared.auth import (
OAuthClientMetadata,
OAuthMetadata,
)
class OAuthClientInformationFull(OAuthClientMetadata):
    """Client metadata extended with the credentials of a registered client.

    Instances are built from a dynamic client registration response together
    with the URL recorded as ``issuer``.
    """

    # URL of the OAuth server that issued this client
    issuer: Optional[str] = None
    client_id: str
    client_secret: Optional[str] = None
    client_id_issued_at: Optional[int] = None
    client_secret_expires_at: Optional[int] = None
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

# Root logging goes to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module logger; its level is taken from the "OAUTH" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
2024-10-16 22:32:57 +08:00
# AppConfig instance holding the OAuth-related settings read by the code below.
auth_manager_config = AppConfig()

# Sign-up and account handling
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE

# Role management
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

# Group management
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM

# Claims used to read the user profile
auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN

# Webhook and token lifetime
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
2024-10-16 22:32:57 +08:00
# Fernet instance used by encrypt_token() / decrypt_token().
FERNET = None

# A 44-character value is encoded and used as the key directly; any other
# length is hashed with SHA-256 and url-safe base64 encoded first.
if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) == 44:
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
else:
    key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes)

try:
    FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
except Exception as e:
    # Log, then let the import fail.
    log.error(f"Error initializing Fernet with provided key: {e}")
    raise
def encrypt_token(token) -> str:
    """Serialize ``token`` to JSON and return it encrypted with ``FERNET`` as text.

    Any failure is logged and re-raised.
    """
    try:
        payload = json.dumps(token).encode()
        return FERNET.encrypt(payload).decode()
    except Exception as e:
        log.error(f"Error encrypting tokens: {e}")
        raise
def decrypt_token(token: str):
    """Decrypt ``token`` with ``FERNET`` and return the JSON-decoded value.

    Any failure is logged and re-raised.
    """
    try:
        plaintext = FERNET.decrypt(token.encode()).decode()
        return json.loads(plaintext)
    except Exception as e:
        log.error(f"Error decrypting tokens: {e}")
        raise
2025-09-03 22:50:02 +08:00
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
    """Tell whether ``group_name`` matches one of the blocked patterns.

    Each non-empty pattern is tried, in this order, as an exact name, as a
    regular expression (only when it contains a regex-specific character),
    and as a shell-style wildcard (only when it contains ``*`` or ``?``).

    Args:
        group_name: The group name to check
        groups: List of patterns to match against

    Returns:
        True if the group is blocked, False otherwise
    """
    regex_markers = ("^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|")

    for pattern in groups or []:
        # Empty patterns never match anything
        if not pattern:
            continue

        if group_name == pattern:
            return True

        if any(marker in pattern for marker in regex_markers):
            try:
                regex_hit = re.search(pattern, group_name) is not None
            except re.error:
                # Not a valid regex: the wildcard check below still applies
                regex_hit = False
            if regex_hit:
                return True

        has_wildcard = "*" in pattern or "?" in pattern
        if has_wildcard and fnmatch.fnmatch(group_name, pattern):
            return True

    return False
def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
    """Parse ``server_url`` and return ``(parsed_url, "<scheme>://<netloc>")``."""
    url_parts = urllib.parse.urlparse(server_url)
    return url_parts, f"{url_parts.scheme}://{url_parts.netloc}"
def get_discovery_urls(server_url) -> list[str]:
    """Return the metadata discovery URLs to try for ``server_url``.

    Both URLs are built on the scheme and host of ``server_url`` only (its
    path is ignored): first the OAuth authorization-server document, then the
    OpenID configuration document.
    """
    url_parts = urllib.parse.urlparse(server_url)
    origin = f"{url_parts.scheme}://{url_parts.netloc}"
    well_known_paths = (
        "/.well-known/oauth-authorization-server",
        "/.well-known/openid-configuration",
    )
    return [urllib.parse.urljoin(origin, path) for path in well_known_paths]
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
# This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
    request, oauth_server_url, oauth_server_key: Optional[str] = None
) -> OAuthClientInformationFull:
    """Register Open WebUI as a client of ``oauth_server_url`` and return the result.

    The server's metadata is looked up through the discovery URLs to find the
    registration endpoint and the supported scopes; when no usable metadata is
    found, ``<scheme>://<host>/register`` is used instead.

    Args:
        request: Incoming request; ``app.state.config.WEBUI_URL`` (or
            ``request.base_url``) is used to build the redirect URI.
        oauth_server_url: URL of the OAuth server to register with.
        oauth_server_key: Accepted but not used by this function.

    Returns:
        The validated registration response, with ``issuer`` set to the
        discovery URL that answered (``None`` when none did).

    Raises:
        Exception: When the registration response cannot be read or validated,
            or when any request fails. Every failure is logged first.
    """
    try:
        oauth_server_metadata = None
        oauth_server_metadata_url = None

        redirect_base_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        oauth_client_metadata = OAuthClientMetadata(
            client_name="Open WebUI",
            redirect_uris=[f"{redirect_base_url}/oauth/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="client_secret_post",
        )

        # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
        for url in get_discovery_urls(oauth_server_url):
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as oauth_server_metadata_response:
                    if oauth_server_metadata_response.status != 200:
                        continue
                    try:
                        oauth_server_metadata = OAuthMetadata.model_validate(
                            await oauth_server_metadata_response.json()
                        )
                        oauth_server_metadata_url = url
                        if (
                            oauth_client_metadata.scope is None
                            and oauth_server_metadata.scopes_supported is not None
                        ):
                            oauth_client_metadata.scope = " ".join(
                                oauth_server_metadata.scopes_supported
                            )
                        break
                    except Exception as e:
                        log.error(f"Error parsing OAuth metadata from {url}: {e}")
                        continue

        if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
            registration_url = str(oauth_server_metadata.registration_endpoint)
        else:
            _, base_url = get_parsed_and_base_url(oauth_server_url)
            registration_url = urllib.parse.urljoin(base_url, "/register")

        registration_data = oauth_client_metadata.model_dump(
            exclude_none=True,
            mode="json",
            by_alias=True,
        )

        # Perform dynamic client registration and return client info
        async with aiohttp.ClientSession() as session:
            async with session.post(
                registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as oauth_client_registration_response:
                try:
                    registration_response_json = (
                        await oauth_client_registration_response.json()
                    )
                    oauth_client_info = OAuthClientInformationFull.model_validate(
                        {
                            **registration_response_json,
                            "issuer": oauth_server_metadata_url,
                        }
                    )
                    log.info(
                        f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
                    )
                    return oauth_client_info
                except Exception as e:
                    error_text = None
                    try:
                        error_text = await oauth_client_registration_response.text()
                        log.error(
                            f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
                        )
                    except Exception as text_error:
                        # Use a separate name here: rebinding ``e`` would unbind
                        # it when this handler exits and break the log line below.
                        log.debug(
                            f"Could not read client registration error body: {text_error}"
                        )

                    log.error(f"Error parsing client registration response: {e}")
                    raise Exception(
                        f"Dynamic client registration failed: {error_text}"
                        if error_text
                        else "Error parsing client registration response"
                    ) from e
    except Exception as e:
        log.error(f"Exception during dynamic client registration: {e}")
        raise
class OAuthClientManager:
    """Registry of OAuth clients keyed by client id, plus token handling for them.

    ``self.clients`` maps a client id to a dict with two keys: ``"client"``
    (the object returned by ``self.oauth.register``) and ``"client_info"``
    (the ``OAuthClientInformationFull`` it was registered from).
    """

    def __init__(self, app):
        self.oauth = OAuth()
        self.app = app
        self.clients = {}

    def add_client(self, client_id, oauth_client_info: "OAuthClientInformationFull"):
        """Register ``oauth_client_info`` under ``client_id`` unless already present.

        Returns:
            The registry entry (``{"client": ..., "client_info": ...}``).
        """
        if client_id not in self.clients:
            self.clients[client_id] = {
                "client": self.oauth.register(
                    name=client_id,
                    client_id=oauth_client_info.client_id,
                    client_secret=oauth_client_info.client_secret,
                    client_kwargs=(
                        {"scope": oauth_client_info.scope}
                        if oauth_client_info.scope
                        else {}
                    ),
                    server_metadata_url=(
                        oauth_client_info.issuer if oauth_client_info.issuer else None
                    ),
                ),
                "client_info": oauth_client_info,
            }
        return self.clients[client_id]

    def remove_client(self, client_id):
        """Drop ``client_id`` from the registry if present; always returns True."""
        if client_id in self.clients:
            del self.clients[client_id]
            log.info(f"Removed OAuth client {client_id}")
        return True

    def get_client(self, client_id):
        """Return the registered client object for ``client_id``, or None."""
        entry = self.clients.get(client_id)
        return entry["client"] if entry else None

    def get_client_info(self, client_id):
        """Return the stored client information for ``client_id``, or None."""
        entry = self.clients.get(client_id)
        return entry["client_info"] if entry else None

    def get_server_metadata_url(self, client_id):
        """Return the server metadata URL for ``client_id``, or None if unknown.

        The registry entry is a dict, so the URL has to be read from the
        client object stored inside it rather than from the entry itself.
        """
        entry = self.clients.get(client_id)
        if not entry:
            return None

        client = entry.get("client")
        # NOTE(review): Authlib clients appear to keep this URL on the private
        # ``_server_metadata_url`` attribute - confirm against the installed
        # version. Both spellings are tried.
        url = getattr(client, "server_metadata_url", None) or getattr(
            client, "_server_metadata_url", None
        )
        if url:
            return url

        # add_client() registers the client with ``client_info.issuer`` as its
        # server_metadata_url, so that value is an equivalent fallback.
        return getattr(entry.get("client_info"), "issuer", None) or None

    async def get_oauth_token(
        self, user_id: str, session_id: str, force_refresh: bool = False
    ):
        """
        Get a valid OAuth token for the user, automatically refreshing if needed.

        Args:
            user_id: The user ID
            session_id: The OAuth session ID
            force_refresh: Force token refresh even if current token appears valid

        Returns:
            dict: OAuth token data with access_token, or None if no valid token available
        """
        try:
            # Get the OAuth session
            session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
            if not session:
                log.warning(
                    f"No OAuth session found for user {user_id}, session {session_id}"
                )
                return None

            # Refresh when forced, or when the token expires within five minutes
            if force_refresh or datetime.now() + timedelta(
                minutes=5
            ) >= datetime.fromtimestamp(session.expires_at):
                log.debug(
                    f"Token refresh needed for user {user_id}, client_id {session.provider}"
                )
                refreshed_token = await self._refresh_token(session)
                if refreshed_token:
                    return refreshed_token

                log.warning(
                    f"Token refresh failed for user {user_id}, client_id {session.provider}"
                )
                return None

            return session.token
        except Exception as e:
            log.error(f"Error getting OAuth token for user {user_id}: {e}")
            return None

    async def _refresh_token(self, session) -> Optional[dict]:
        """
        Refresh the token of ``session`` and store the result.

        Args:
            session: The OAuth session object

        Returns:
            dict: Refreshed token data, or None if refresh failed
        """
        # Captured up front so that logging never depends on a rebound variable.
        session_id = session.id
        try:
            refreshed_token = await self._perform_token_refresh(session)
            if not refreshed_token:
                log.error(f"Failed to refresh token for session {session_id}")
                return None

            # Update the session with new token data
            updated_session = OAuthSessions.update_session_by_id(
                session_id, refreshed_token
            )
            # Guard in case the update hands back nothing - presumably on a
            # missing row or failed write; verify against OAuthSessions.
            if not updated_session:
                log.error(
                    f"Failed to store refreshed token for session {session_id}"
                )
                return None

            log.info(f"Successfully refreshed token for session {session_id}")
            return updated_session.token
        except Exception as e:
            log.error(f"Error refreshing token for session {session_id}: {e}")
            return None

    async def _perform_token_refresh(self, session) -> Optional[dict]:
        """
        Perform the actual OAuth token refresh.

        The token endpoint is read from the server metadata document, then a
        ``refresh_token`` grant is posted to it.

        Args:
            session: The OAuth session object

        Returns:
            dict: New token data, or None if refresh failed
        """
        client_id = session.provider
        token_data = session.token

        if not token_data.get("refresh_token"):
            log.warning(f"No refresh token available for session {session.id}")
            return None

        try:
            client = self.get_client(client_id)
            if not client:
                log.error(f"No OAuth client found for provider {client_id}")
                return None

            server_metadata_url = self.get_server_metadata_url(client_id)
            if not server_metadata_url:
                log.error(f"No server metadata URL found for client_id {client_id}")
                return None

            token_endpoint = None
            async with aiohttp.ClientSession(trust_env=True) as session_http:
                async with session_http.get(
                    server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as r:
                    if r.status == 200:
                        openid_data = await r.json()
                        token_endpoint = openid_data.get("token_endpoint")
                    else:
                        log.error(
                            f"Failed to fetch OpenID configuration for client_id {client_id}"
                        )

            if not token_endpoint:
                log.error(f"No token endpoint found for client_id {client_id}")
                return None

            # Prepare refresh request
            refresh_data = {
                "grant_type": "refresh_token",
                "refresh_token": token_data["refresh_token"],
                "client_id": client.client_id,
            }
            if hasattr(client, "client_secret") and client.client_secret:
                refresh_data["client_secret"] = client.client_secret

            # Make refresh request
            async with aiohttp.ClientSession(trust_env=True) as session_http:
                async with session_http.post(
                    token_endpoint,
                    data=refresh_data,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        error_text = await r.text()
                        log.error(
                            f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
                        )
                        return None

                    new_token_data = await r.json()

                    # Merge with existing token data (preserve refresh_token if not provided)
                    if "refresh_token" not in new_token_data:
                        new_token_data["refresh_token"] = token_data["refresh_token"]

                    # Add timestamp for tracking
                    new_token_data["issued_at"] = datetime.now().timestamp()

                    # Calculate expires_at if we have expires_in
                    if (
                        "expires_in" in new_token_data
                        and "expires_at" not in new_token_data
                    ):
                        new_token_data["expires_at"] = int(
                            datetime.now().timestamp() + new_token_data["expires_in"]
                        )

                    log.debug(f"Token refresh successful for client_id {client_id}")
                    return new_token_data
        except Exception as e:
            log.error(f"Exception during token refresh for client_id {client_id}: {e}")
            return None

    async def handle_authorize(self, request, client_id: str) -> "RedirectResponse":
        """Redirect to the authorization endpoint of ``client_id``.

        Raises:
            HTTPException: 404 when the client or its information is unknown.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        client_info = self.get_client_info(client_id)
        if client_info is None:
            raise HTTPException(404)

        redirect_uri = (
            client_info.redirect_uris[0] if client_info.redirect_uris else None
        )
        return await client.authorize_redirect(request, redirect_uri)

    async def handle_callback(self, request, client_id: str, user_id: str, response):
        """Finish the authorization flow and store the token for ``user_id``.

        Existing sessions of the user for the same client are replaced.

        Returns:
            A redirect to ``<base>/auth``, carrying ``?error=...`` on failure.

        Raises:
            HTTPException: 404 when the client is unknown.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        error_message = None
        try:
            token = await client.authorize_access_token(request)
            if token:
                try:
                    # Add timestamp for tracking
                    token["issued_at"] = datetime.now().timestamp()

                    # Calculate expires_at if we have expires_in
                    # (int, as in the refresh path)
                    if "expires_in" in token and "expires_at" not in token:
                        token["expires_at"] = int(
                            datetime.now().timestamp() + token["expires_in"]
                        )

                    # Clean up any existing sessions for this user/client_id first
                    for existing in OAuthSessions.get_sessions_by_user_id(user_id):
                        if existing.provider == client_id:
                            OAuthSessions.delete_session_by_id(existing.id)

                    OAuthSessions.create_session(
                        user_id=user_id,
                        provider=client_id,
                        token=token,
                    )
                    log.info(
                        f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
                    )
                except Exception as e:
                    error_message = "Failed to store OAuth session server-side"
                    log.error(f"Failed to store OAuth session server-side: {e}")
            else:
                error_message = "Failed to obtain OAuth token"
                log.warning(error_message)
        except Exception as e:
            error_message = "OAuth callback error"
            log.warning(f"OAuth callback error: {e}")

        redirect_base_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")
        redirect_url = f"{redirect_base_url}/auth"
        if error_message:
            redirect_url = f"{redirect_url}?error={error_message}"

        # Success and failure both answer with a redirect.
        return RedirectResponse(url=redirect_url, headers=response.headers)
2024-10-16 22:32:57 +08:00
class OAuthManager:
2025-02-16 16:11:18 +08:00
def __init__(self, app):
    """Create the OAuth registry and let every configured provider register on it."""
    self.oauth = OAuth()
    self.app = app
    # provider name -> client, filled on demand by get_client()
    self._clients = {}

    for provider_config in OAUTH_PROVIDERS.values():
        register_provider = provider_config["register"]
        register_provider(self.oauth)
2024-10-16 22:32:57 +08:00
def get_client(self, provider_name):
    """Return the client for ``provider_name``, creating and caching it on first use."""
    try:
        return self._clients[provider_name]
    except KeyError:
        created = self.oauth.create_client(provider_name)
        self._clients[provider_name] = created
        return created
def get_server_metadata_url(self, provider_name):
    """Return the server metadata URL of an already created client, or None.

    Only clients present in ``self._clients`` (i.e. created through
    ``get_client``) are considered.
    """
    client = self._clients.get(provider_name)
    if client is None:
        return None

    # NOTE(review): Authlib clients appear to keep this URL on the private
    # ``_server_metadata_url`` attribute - confirm against the installed
    # version. The public spelling is tried first, then the private one.
    return (
        getattr(client, "server_metadata_url", None)
        or getattr(client, "_server_metadata_url", None)
        or None
    )
2025-09-19 13:10:48 +08:00
async def get_oauth_token(
    self, user_id: str, session_id: str, force_refresh: bool = False
):
    """
    Return a usable OAuth token for a stored session, refreshing it when needed.

    Args:
        user_id: The user ID
        session_id: The OAuth session ID
        force_refresh: Force token refresh even if current token appears valid

    Returns:
        dict: OAuth token data, or None when the session is missing, the
        refresh fails, or any error occurs
    """
    try:
        session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
        if not session:
            log.warning(
                f"No OAuth session found for user {user_id}, session {session_id}"
            )
            return None

        # Refresh when asked to, or when the token expires within five minutes.
        needs_refresh = force_refresh or (
            datetime.now() + timedelta(minutes=5)
            >= datetime.fromtimestamp(session.expires_at)
        )
        if not needs_refresh:
            return session.token

        log.debug(
            f"Token refresh needed for user {user_id}, provider {session.provider}"
        )
        refreshed_token = await self._refresh_token(session)
        if not refreshed_token:
            log.warning(
                f"Token refresh failed for user {user_id}, provider {session.provider}"
            )
            return None
        return refreshed_token
    except Exception as e:
        log.error(f"Error getting OAuth token for user {user_id}: {e}")
        return None
async def _refresh_token(self, session) -> Optional[dict]:
    """
    Refresh the token of ``session`` and store the result.

    Args:
        session: The OAuth session object

    Returns:
        dict: Refreshed token data, or None if refresh failed
    """
    # Captured up front so that logging never depends on a rebound variable.
    session_id = session.id
    try:
        refreshed_token = await self._perform_token_refresh(session)
        if not refreshed_token:
            log.error(f"Failed to refresh token for session {session_id}")
            return None

        # Update the session with new token data
        updated_session = OAuthSessions.update_session_by_id(
            session_id, refreshed_token
        )
        # Guard in case the update hands back nothing - presumably on a
        # missing row or failed write; verify against OAuthSessions.
        if not updated_session:
            log.error(f"Failed to store refreshed token for session {session_id}")
            return None

        log.info(f"Successfully refreshed token for session {session_id}")
        return updated_session.token
    except Exception as e:
        log.error(f"Error refreshing token for session {session_id}: {e}")
        return None
async def _perform_token_refresh(self, session) -> Optional[dict]:
    """
    Perform the actual OAuth token refresh.

    The token endpoint is read from the provider's server metadata document,
    then a ``refresh_token`` grant is posted to it.

    Args:
        session: The OAuth session object

    Returns:
        dict: New token data, or None if refresh failed
    """
    provider = session.provider
    token_data = session.token

    if not token_data.get("refresh_token"):
        log.warning(f"No refresh token available for session {session.id}")
        return None

    try:
        client = self.get_client(provider)
        if not client:
            log.error(f"No OAuth client found for provider {provider}")
            return None

        server_metadata_url = self.get_server_metadata_url(provider)
        if not server_metadata_url:
            # Without a metadata URL there is nothing to request.
            log.error(f"No server metadata URL found for provider {provider}")
            return None

        token_endpoint = None
        async with aiohttp.ClientSession(trust_env=True) as session_http:
            async with session_http.get(
                server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as r:
                if r.status == 200:
                    openid_data = await r.json()
                    token_endpoint = openid_data.get("token_endpoint")
                else:
                    log.error(
                        f"Failed to fetch OpenID configuration for provider {provider}"
                    )

        if not token_endpoint:
            log.error(f"No token endpoint found for provider {provider}")
            return None

        # Prepare refresh request
        refresh_data = {
            "grant_type": "refresh_token",
            "refresh_token": token_data["refresh_token"],
            "client_id": client.client_id,
        }
        # Add client_secret if available (some providers require it)
        if hasattr(client, "client_secret") and client.client_secret:
            refresh_data["client_secret"] = client.client_secret

        # Make refresh request
        async with aiohttp.ClientSession(trust_env=True) as session_http:
            async with session_http.post(
                token_endpoint,
                data=refresh_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    error_text = await r.text()
                    log.error(
                        f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
                    )
                    return None

                new_token_data = await r.json()

                # Merge with existing token data (preserve refresh_token if not provided)
                if "refresh_token" not in new_token_data:
                    new_token_data["refresh_token"] = token_data["refresh_token"]

                # Add timestamp for tracking
                new_token_data["issued_at"] = datetime.now().timestamp()

                # Calculate expires_at if we have expires_in
                if (
                    "expires_in" in new_token_data
                    and "expires_at" not in new_token_data
                ):
                    new_token_data["expires_at"] = int(
                        datetime.now().timestamp() + new_token_data["expires_in"]
                    )

                log.debug(f"Token refresh successful for provider {provider}")
                return new_token_data
    except Exception as e:
        log.error(f"Exception during token refresh for provider {provider}: {e}")
        return None
2024-10-16 22:32:57 +08:00
def get_user_role(self, user, user_data):
    """Decide the role of the user who is logging in.

    Args:
        user: The existing user object, or a falsy value for a new user.
        user_data: Claims returned by the provider; read with ``.get``.

    Returns:
        ``"admin"`` for the only / first user. With role management enabled:
        ``"admin"`` or ``"user"`` when a claimed role matches the admin or
        allowed roles, otherwise the default role. With role management
        disabled: the existing user's role, or the default role for a new user.
    """
    user_count = Users.get_num_users()
    if user and user_count == 1:
        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
        log.debug("Assigning the only user the admin role")
        return "admin"
    if not user and user_count == 0:
        # If there are no users, assign the role "admin", as the first user will be an admin
        log.debug("Assigning the first user the admin role")
        return "admin"

    if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
        log.debug("Running OAUTH Role management")
        oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
        oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
        oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
        oauth_roles = []
        # Default/fallback role if no matching roles are found
        role = auth_manager_config.DEFAULT_USER_ROLE

        # Next block extracts the roles from the user data, accepting nested claims of any depth
        if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
            claim_data = user_data
            for nested_claim in oauth_claim.split("."):
                if not isinstance(claim_data, dict):
                    # An intermediate value that is not a mapping cannot be
                    # descended into; treat the claim as absent.
                    claim_data = {}
                    break
                claim_data = claim_data.get(nested_claim, {})

            if isinstance(claim_data, list):
                oauth_roles = claim_data
            elif isinstance(claim_data, (str, int)):
                oauth_roles = [str(claim_data)]

        log.debug(f"Oauth Roles claim: {oauth_claim}")
        log.debug(f"User roles from oauth: {oauth_roles}")
        log.debug(f"Accepted user roles: {oauth_allowed_roles}")
        log.debug(f"Accepted admin roles: {oauth_admin_roles}")

        # If any roles are found, check if they match the allowed or admin roles
        if oauth_roles:
            # If the user has any of the allowed roles, assign the role "user"
            if any(allowed_role in oauth_roles for allowed_role in oauth_allowed_roles):
                log.debug("Assigned user the user role")
                role = "user"
            # If the user has any of the admin roles, assign the role "admin"
            # (checked last so that it wins over "user")
            if any(admin_role in oauth_roles for admin_role in oauth_admin_roles):
                log.debug("Assigned user the admin role")
                role = "admin"
    else:
        if not user:
            # If role management is disabled, use the default role for new users
            role = auth_manager_config.DEFAULT_USER_ROLE
        else:
            # If role management is disabled, use the existing role for existing users
            role = user.role

    return role
2024-12-18 05:51:29 +08:00
def update_user_groups(self, user, user_data, default_permissions):
    """Synchronise the user's group memberships with the groups claim.

    Steps, in order: read the (possibly nested) groups claim; optionally
    create groups named in the claim that do not exist yet; remove the user
    from current groups missing from the claim; add the user to existing
    groups named in the claim. Groups matching a blocked pattern are neither
    removed from nor added to. Nothing is removed or added when the claim
    yields no groups.

    Args:
        user: The user whose memberships are updated.
        user_data: Claims returned by the provider; read with ``.get``.
        default_permissions: Permissions used for created groups and for
            groups that have none set.
    """
    log.debug("Running OAUTH Group management")
    oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM

    try:
        blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
    except Exception as e:
        log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
        blocked_groups = []

    user_oauth_groups = []
    # Nested claim search for groups claim
    if oauth_claim:
        claim_data = user_data
        for nested_claim in oauth_claim.split("."):
            if not isinstance(claim_data, dict):
                # An intermediate value that is not a mapping cannot be
                # descended into; treat the claim as absent.
                claim_data = {}
                break
            claim_data = claim_data.get(nested_claim, {})

        if isinstance(claim_data, list):
            user_oauth_groups = claim_data
        elif isinstance(claim_data, str):
            user_oauth_groups = [claim_data]

    user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
    all_available_groups: list[GroupModel] = Groups.get_groups()

    # Create groups if they don't exist and creation is enabled
    if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
        log.debug("Checking for missing groups to create...")
        all_group_names = {g.name for g in all_available_groups}
        groups_created = False
        # Determine creator ID: Prefer admin, fallback to current user if no admin exists
        admin_user = Users.get_super_admin_user()
        creator_id = admin_user.id if admin_user else user.id
        log.debug(f"Using creator ID {creator_id} for potential group creation.")

        for group_name in user_oauth_groups:
            if group_name in all_group_names:
                continue
            log.info(
                f"Group '{group_name}' not found via OAuth claim. Creating group..."
            )
            try:
                new_group_form = GroupForm(
                    name=group_name,
                    description=f"Group '{group_name}' created automatically via OAuth.",
                    permissions=default_permissions,  # Use default permissions from function args
                    user_ids=[],  # Start with no users, user will be added later by subsequent logic
                )
                # Use determined creator ID (admin or fallback to current user)
                created_group = Groups.insert_new_group(creator_id, new_group_form)
                if created_group:
                    log.info(
                        f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
                    )
                    groups_created = True
                    # Add to local set to prevent duplicate creation attempts in this run
                    all_group_names.add(group_name)
                else:
                    log.error(f"Failed to create group '{group_name}' via OAuth.")
            except Exception as e:
                log.error(f"Error creating group '{group_name}' via OAuth: {e}")

        # Refresh the list of all available groups if any were created
        if groups_created:
            all_available_groups = Groups.get_groups()
            log.debug("Refreshed list of all available groups after creation.")

    log.debug(f"Oauth Groups claim: {oauth_claim}")
    log.debug(f"User oauth groups: {user_oauth_groups}")
    log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
    log.debug(
        f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
    )

    def _save_members(group_model, user_ids):
        # Store a new member list for one group, keeping its other fields.
        update_form = GroupUpdateForm(
            name=group_model.name,
            description=group_model.description,
            # In case a group is created, but perms are never assigned to the group by hitting "save"
            permissions=group_model.permissions or default_permissions,
            user_ids=user_ids,
        )
        Groups.update_group_by_id(
            id=group_model.id, form_data=update_form, overwrite=False
        )

    # Nothing is removed or added when the claim yielded no groups
    if not user_oauth_groups:
        return

    current_group_names = {g.name for g in user_current_groups}

    # Remove groups that user is no longer a part of
    for group_model in user_current_groups:
        if group_model.name not in user_oauth_groups and not is_in_blocked_groups(
            group_model.name, blocked_groups
        ):
            log.debug(
                f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
            )
            remaining_ids = [i for i in (group_model.user_ids or []) if i != user.id]
            _save_members(group_model, remaining_ids)

    # Add user to new groups
    for group_model in all_available_groups:
        if (
            group_model.name in user_oauth_groups
            and group_model.name not in current_group_names
            and not is_in_blocked_groups(group_model.name, blocked_groups)
        ):
            log.debug(
                f"Adding user to group {group_model.name} as it was found in their oauth groups"
            )
            # Build a new list instead of appending to the model's own list
            extended_ids = [*(group_model.user_ids or []), user.id]
            _save_members(group_model, extended_ids)
2024-10-16 22:32:57 +08:00
2025-05-07 00:00:35 +08:00
async def _process_picture_url(
self, picture_url: str, access_token: str = None
) -> str:
"""Process a picture URL and return a base64 encoded data URL.
Args:
picture_url: The URL of the picture to process
access_token: Optional OAuth access token for authenticated requests
Returns:
A data URL containing the base64 encoded picture, or "/user.png" if processing fails
"""
if not picture_url:
return "/user.png"
try:
get_kwargs = {}
if access_token:
get_kwargs["headers"] = {
"Authorization": f"Bearer {access_token}",
}
2025-05-15 03:27:34 +08:00
async with aiohttp.ClientSession(trust_env=True) as session:
2025-05-15 03:33:52 +08:00
async with session.get(
picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as resp:
2025-05-07 00:00:35 +08:00
if resp.ok:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
guessed_mime_type = "image/jpeg"
return (
f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
)
else:
log.warning(
f"Failed to fetch profile picture from {picture_url}"
)
return "/user.png"
except Exception as e:
log.error(f"Error processing profile picture '{picture_url}': {e}")
return "/user.png"
2025-02-16 16:11:18 +08:00
async def handle_login(self, request, provider):
    """Redirect the browser to the authorization page of ``provider``.

    Raises:
        HTTPException: 404 when the provider is not configured or no client
            exists for it.
    """
    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    # If the provider has a custom redirect URL, use that, otherwise automatically generate one
    configured_uri = OAUTH_PROVIDERS[provider].get("redirect_uri")
    if configured_uri:
        redirect_uri = configured_uri
    else:
        redirect_uri = request.url_for("oauth_callback", provider=provider)

    client = self.get_client(provider)
    if client is None:
        raise HTTPException(404)

    return await client.authorize_redirect(request, redirect_uri)
2025-02-16 16:11:18 +08:00
async def handle_callback(self, request, provider, response):
    """Finish the OAuth flow for ``provider`` and redirect to the frontend.

    Exchanges the authorization response for a token, resolves the user's
    claims (sub, email, name, picture), finds/merges/creates the local user,
    issues a JWT and stores the OAuth token in a server-side session.

    Args:
        request: The incoming callback request.
        provider: Key of the provider in ``OAUTH_PROVIDERS``.
        response: Response object whose headers are copied onto the redirect.

    Returns:
        A ``RedirectResponse`` to ``<base>/auth``. On failure the URL carries
        an ``error`` query parameter; on success the response carries the
        ``token`` cookie (plus the optional legacy ``oauth_id_token`` cookie
        and, when the session could be stored, ``oauth_session_id``).

    Raises:
        HTTPException: 404 when the provider is not configured. Every other
            failure inside the flow is caught and reported through the
            redirect's ``error`` parameter instead of being raised.
    """
    # Function-scope import: only the bare `urllib` package is imported at
    # the top of the file, which does not guarantee `urllib.parse` is bound.
    from urllib.parse import quote

    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    error_message = None
    try:
        client = self.get_client(provider)
        try:
            token = await client.authorize_access_token(request)
        except Exception as e:
            log.warning(f"OAuth callback error: {e}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Try to get userinfo from the token first, some providers include it there
        user_data: UserInfo = token.get("userinfo")
        if (
            (not user_data)
            or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
            or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
        ):
            user_data: UserInfo = await client.userinfo(token=token)

        # Feishu wraps the claims in a "data" envelope; unwrap it.
        if (
            provider == "feishu"
            and isinstance(user_data, dict)
            and "data" in user_data
        ):
            user_data = user_data["data"]

        if not user_data:
            # Never log the token itself: it holds access/refresh/id tokens.
            log.warning(
                f"OAuth callback failed, user data is missing for provider {provider}"
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Extract the "sub" claim, using custom claim if configured
        if auth_manager_config.OAUTH_SUB_CLAIM:
            sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
        else:
            # Fallback to the default sub claim if not configured
            sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
        if not sub:
            log.warning(f"OAuth callback failed, sub is missing: {user_data}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        provider_sub = f"{provider}@{sub}"

        # Email extraction
        email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
        email = user_data.get(email_claim, "")
        # We currently mandate that email addresses are provided
        if not email:
            # If the provider is GitHub and a public email is not provided,
            # use the access token to fetch the user's primary email.
            if provider == "github":
                try:
                    access_token = token.get("access_token")
                    headers = {"Authorization": f"Bearer {access_token}"}
                    async with aiohttp.ClientSession(trust_env=True) as session:
                        async with session.get(
                            "https://api.github.com/user/emails",
                            headers=headers,
                            ssl=AIOHTTP_CLIENT_SESSION_SSL,
                        ) as resp:
                            if resp.ok:
                                emails = await resp.json()
                                # use the primary email as the user's email
                                primary_email = next(
                                    (
                                        e["email"]
                                        for e in emails
                                        if e.get("primary")
                                    ),
                                    None,
                                )
                                if primary_email:
                                    email = primary_email
                                else:
                                    log.warning(
                                        "No primary email found in GitHub response"
                                    )
                                    raise HTTPException(
                                        400, detail=ERROR_MESSAGES.INVALID_CRED
                                    )
                            else:
                                log.warning("Failed to fetch GitHub email")
                                raise HTTPException(
                                    400, detail=ERROR_MESSAGES.INVALID_CRED
                                )
                except Exception as e:
                    # Also re-wraps the HTTPExceptions raised just above.
                    log.warning(f"Error fetching GitHub email: {e}")
                    raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
            else:
                log.warning(f"OAuth callback failed, email is missing: {user_data}")
                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        email = email.lower()

        # If allowed domains are configured, check if the email domain is in the list
        if (
            "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
            and email.split("@")[-1]
            not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
        ):
            log.warning(
                f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Check if the user exists
        user = Users.get_user_by_oauth_sub(provider_sub)
        if not user:
            # If the user does not exist, check if merging is enabled
            if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
                # Check if the user exists by email
                user = Users.get_user_by_email(email)
                if user:
                    # Update the user with the new oauth sub
                    Users.update_user_oauth_sub_by_id(user.id, provider_sub)

        if user:
            determined_role = self.get_user_role(user, user_data)
            if user.role != determined_role:
                Users.update_user_role_by_id(user.id, determined_role)

            # Update profile picture if enabled and different from current
            if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
                if picture_claim:
                    new_picture_url = user_data.get(
                        picture_claim,
                        OAUTH_PROVIDERS[provider].get("picture_url", ""),
                    )
                    processed_picture_url = await self._process_picture_url(
                        new_picture_url, token.get("access_token")
                    )
                    if processed_picture_url != user.profile_image_url:
                        Users.update_user_profile_image_url_by_id(
                            user.id, processed_picture_url
                        )
                        log.debug(f"Updated profile picture for user {user.email}")
        else:
            # If the user does not exist, check if signups are enabled
            if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                # Check if an existing user with the same email already exists
                existing_user = Users.get_user_by_email(email)
                if existing_user:
                    raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
                if picture_claim:
                    picture_url = user_data.get(
                        picture_claim,
                        OAUTH_PROVIDERS[provider].get("picture_url", ""),
                    )
                    picture_url = await self._process_picture_url(
                        picture_url, token.get("access_token")
                    )
                else:
                    picture_url = "/user.png"

                username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
                name = user_data.get(username_claim)
                if not name:
                    log.warning("Username claim is missing, using email as name")
                    name = email

                user = Auths.insert_new_auth(
                    email=email,
                    password=get_password_hash(
                        str(uuid.uuid4())
                    ),  # Random password, not used
                    name=name,
                    profile_image_url=picture_url,
                    role=self.get_user_role(None, user_data),
                    oauth_sub=provider_sub,
                )

                if auth_manager_config.WEBHOOK_URL:
                    await post_webhook(
                        WEBUI_NAME,
                        auth_manager_config.WEBHOOK_URL,
                        WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                        {
                            "action": "signup",
                            "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                            "user": user.model_dump_json(exclude_none=True),
                        },
                    )
            else:
                raise HTTPException(
                    status.HTTP_403_FORBIDDEN,
                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                )

        jwt_token = create_token(
            data={"id": user.id},
            expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
        )

        # NOTE(review): `user` is not reassigned after update_user_role_by_id
        # above, so this check presumably sees the role loaded before the
        # update; confirm whether the freshly determined role should be used.
        if (
            auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
            and user.role != "admin"
        ):
            self.update_user_groups(
                user=user,
                user_data=user_data,
                default_permissions=request.app.state.config.USER_PERMISSIONS,
            )
    except Exception as e:
        # Any failure above is reported to the frontend via the redirect URL.
        log.error(f"Error during OAuth process: {e}")
        error_message = (
            e.detail
            if isinstance(e, HTTPException) and e.detail
            else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
        )

    redirect_base_url = (
        str(request.app.state.config.WEBUI_URL or request.base_url)
    ).rstrip("/")
    redirect_url = f"{redirect_base_url}/auth"

    if error_message:
        # Percent-encode the message: a raw `&`, `#`, `+` or `=` would be
        # parsed as URL syntax and truncate or corrupt the `error` parameter.
        encoded_error = quote(str(error_message), safe="")
        redirect_url = f"{redirect_url}?error={encoded_error}"
        return RedirectResponse(url=redirect_url, headers=response.headers)

    # Redirect back to the frontend with the JWT token set as a cookie
    response = RedirectResponse(url=redirect_url, headers=response.headers)
    response.set_cookie(
        key="token",
        value=jwt_token,
        httponly=False,  # Required for frontend access
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    # Legacy cookies for compatibility with older frontend versions
    if ENABLE_OAUTH_ID_TOKEN_COOKIE:
        # NOTE(review): token.get("id_token") may be None for providers that
        # issue no ID token; confirm what cookie value that produces.
        response.set_cookie(
            key="oauth_id_token",
            value=token.get("id_token"),
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

    try:
        # Sample the clock once so issued_at and expires_at share one base.
        issued_at = datetime.now().timestamp()

        # Add timestamp for tracking
        token["issued_at"] = issued_at

        # Calculate expires_at if we have expires_in
        if "expires_in" in token and "expires_at" not in token:
            token["expires_at"] = issued_at + token["expires_in"]

        # Clean up any existing sessions for this user/provider first
        sessions = OAuthSessions.get_sessions_by_user_id(user.id)
        for session in sessions:
            if session.provider == provider:
                OAuthSessions.delete_session_by_id(session.id)

        session = OAuthSessions.create_session(
            user_id=user.id,
            provider=provider,
            token=token,
        )
        response.set_cookie(
            key="oauth_session_id",
            value=session.id,
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )
        log.info(
            f"Stored OAuth session server-side for user {user.id}, provider {provider}"
        )
    except Exception as e:
        # Best effort: login still succeeds without a stored OAuth session.
        log.error(f"Failed to store OAuth session server-side: {e}")

    return response