feat: oauth2.1 mcp integration

This commit is contained in:
Timothy Jaeryang Baek 2025-09-25 01:49:16 -05:00
parent 972be4eda5
commit 77e971dd9f
10 changed files with 248 additions and 53 deletions

View File

@ -473,7 +473,12 @@ from open_webui.utils.auth import (
get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.oauth import (
OAuthManager,
OAuthClientManager,
decrypt_data,
OAuthClientInformationFull,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
@ -603,9 +608,14 @@ app = FastAPI(
lifespan=lifespan,
)
# For Open WebUI OIDC/OAuth2
oauth_manager = OAuthManager(app)
app.state.oauth_manager = oauth_manager
# For Integrations
oauth_client_manager = OAuthClientManager(app)
app.state.oauth_client_manager = oauth_client_manager
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
@ -1881,6 +1891,24 @@ async def get_current_usage(user=Depends(get_verified_user)):
# OAuth Login & Callback
############################
# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
if tool_server_connection.get("type", "openapi") == "mcp":
server_id = tool_server_connection.get("info", {}).get("id")
auth_type = tool_server_connection.get("auth_type", "none")
if server_id and auth_type == "oauth_2.1":
oauth_client_info = tool_server_connection.get("info", {}).get(
"oauth_client_info"
)
oauth_client_info = decrypt_data(oauth_client_info)
app.state.oauth_client_manager.add_client(
f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
try:
@ -1913,6 +1941,31 @@ if len(OAUTH_PROVIDERS) > 0:
)
@app.get("/oauth/clients/{client_id}/authorize")
async def oauth_client_authorize(
client_id: str,
request: Request,
response: Response,
user=Depends(get_verified_user),
):
return await oauth_client_manager.handle_authorize(request, client_id=client_id)
@app.get("/oauth/clients/{client_id}/callback")
async def oauth_client_callback(
client_id: str,
request: Request,
response: Response,
user=Depends(get_verified_user),
):
return await oauth_client_manager.handle_callback(
request,
client_id=client_id,
user_id=user.id if user else None,
response=response,
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
return await oauth_manager.handle_login(request, provider)
@ -1924,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
@app.get("/oauth/{provider}/callback") # Legacy endpoint
@app.get("/oauth/{provider}/login/callback")
async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)

View File

@ -176,6 +176,26 @@ class OAuthSessionTable:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_provider_and_user_id(
self, provider: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:

View File

@ -21,7 +21,9 @@ from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.oauth import (
get_discovery_urls,
get_oauth_client_info_with_dynamic_client_registration,
encrypt_token,
encrypt_data,
decrypt_data,
OAuthClientInformationFull,
)
from mcp.shared.auth import OAuthMetadata
@ -103,17 +105,22 @@ class OAuthClientRegistrationForm(BaseModel):
async def register_oauth_client(
request: Request,
form_data: OAuthClientRegistrationForm,
type: Optional[str] = None,
user=Depends(get_admin_user),
):
try:
oauth_client_id = form_data.client_id
if type:
oauth_client_id = f"{type}:{form_data.client_id}"
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request, form_data.url
request, oauth_client_id, form_data.url
)
)
return {
"status": True,
"oauth_client_info": encrypt_token(
"oauth_client_info": encrypt_data(
oauth_client_info.model_dump(mode="json")
),
}
@ -161,8 +168,25 @@ async def set_tool_servers_config(
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
await set_tool_servers(request)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi")
if server_type == "mcp":
server_id = connection.get("info", {}).get("id")
auth_type = connection.get("auth_type", "none")
if auth_type == "oauth_2.1" and server_id:
try:
oauth_client_info = decrypt_data(oauth_client_info)
await request.app.state.oauth_client_manager.add_client(
f"{server_type}:{server_id}",
OAuthClientInformationFull(**oauth_client_info),
)
except Exception as e:
log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
continue
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}

View File

@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.tools import (
ToolForm,
ToolModel,
@ -80,6 +81,24 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
# MCP Tool Servers
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if server.get("type", "openapi") == "mcp":
server_id = server.get("info", {}).get("id")
auth_type = server.get("auth_type", "none")
session_token = None
if auth_type == "oauth_2.1":
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
session_token = (
await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
)
print("User ID:", user.id)
print("Server ID:", server_id)
print("MCP Session Token:", session_token)
tools.append(
ToolUserResponse(
**{
@ -96,6 +115,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
),
"updated_at": int(time.time()),
"created_at": int(time.time()),
**(
{
"authenticated": session_token is not None,
}
if auth_type == "oauth_2.1"
else {}
),
}
)
)

View File

@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users
@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model):
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
elif auth_type == "oauth_2.1":
try:
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
mcp_client = MCPClient()
await mcp_client.connect(

View File

@ -126,24 +126,24 @@ except Exception as e:
raise
def encrypt_token(token) -> str:
"""Encrypt OAuth tokens for storage"""
def encrypt_data(data) -> str:
"""Encrypt data for storage"""
try:
token_json = json.dumps(token)
encrypted = FERNET.encrypt(token_json.encode()).decode()
data_json = json.dumps(data)
encrypted = FERNET.encrypt(data_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
log.error(f"Error encrypting data: {e}")
raise
def decrypt_token(token: str):
"""Decrypt OAuth tokens from storage"""
def decrypt_data(data: str):
"""Decrypt data from storage"""
try:
decrypted = FERNET.decrypt(token.encode()).decode()
decrypted = FERNET.decrypt(data.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
log.error(f"Error decrypting data: {e}")
raise
@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]:
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
# This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
request, oauth_server_url, oauth_server_key: Optional[str] = None
request,
client_id: str,
oauth_server_url: str,
oauth_server_key: Optional[str] = None,
) -> OAuthClientInformationFull:
try:
oauth_server_metadata = None
@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
redirect_base_url = (
str(request.app.state.config.WEBUI_URL or request.base_url)
).rstrip("/")
oauth_client_metadata = OAuthClientMetadata(
client_name="Open WebUI",
redirect_uris=[f"{redirect_base_url}/oauth/callback"],
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="client_secret_post",
@ -315,23 +319,22 @@ class OAuthClientManager:
self.clients = {}
def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
if client_id not in self.clients:
self.clients[client_id] = {
"client": self.oauth.register(
name=client_id,
client_id=oauth_client_info.client_id,
client_secret=oauth_client_info.client_secret,
client_kwargs=(
{"scope": oauth_client_info.scope}
if oauth_client_info.scope
else {}
),
server_metadata_url=(
oauth_client_info.issuer if oauth_client_info.issuer else None
),
self.clients[client_id] = {
"client": self.oauth.register(
name=client_id,
client_id=oauth_client_info.client_id,
client_secret=oauth_client_info.client_secret,
client_kwargs=(
{"scope": oauth_client_info.scope}
if oauth_client_info.scope
else {}
),
"client_info": oauth_client_info,
}
server_metadata_url=(
oauth_client_info.issuer if oauth_client_info.issuer else None
),
),
"client_info": oauth_client_info,
}
return self.clients[client_id]
def remove_client(self, client_id):
@ -359,7 +362,7 @@ class OAuthClientManager:
return None
async def get_oauth_token(
self, user_id: str, session_id: str, force_refresh: bool = False
self, user_id: str, client_id: str, force_refresh: bool = False
):
"""
Get a valid OAuth token for the user, automatically refreshing if needed.
@ -374,10 +377,12 @@ class OAuthClientManager:
"""
try:
# Get the OAuth session
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
session = OAuthSessions.get_session_by_provider_and_user_id(
client_id, user_id
)
if not session:
log.warning(
f"No OAuth session found for user {user_id}, session {session_id}"
f"No OAuth session found for user {user_id}, client_id {client_id}"
)
return None
@ -392,8 +397,9 @@ class OAuthClientManager:
return refreshed_token
else:
log.warning(
f"Token refresh failed for user {user_id}, client_id {session.provider}"
f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
)
OAuthSessions.delete_session_by_id(session.id)
return None
return session.token
@ -533,7 +539,7 @@ class OAuthClientManager:
redirect_uri = (
client_info.redirect_uris[0] if client_info.redirect_uris else None
)
return await client.authorize_redirect(request, redirect_uri)
return await client.authorize_redirect(request, str(redirect_uri))
async def handle_callback(self, request, client_id: str, user_id: str, response):
client = self.get_client(client_id)
@ -565,7 +571,6 @@ class OAuthClientManager:
provider=client_id,
token=token,
)
log.info(
f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
)
@ -579,16 +584,17 @@ class OAuthClientManager:
error_message = "OAuth callback error"
log.warning(f"OAuth callback error: {e}")
redirect_base_url = (
redirect_url = (
str(request.app.state.config.WEBUI_URL or request.base_url)
).rstrip("/")
redirect_url = f"{redirect_base_url}/auth"
if error_message:
redirect_url = f"{redirect_url}?error={error_message}"
log.debug(error_message)
redirect_url = f"{redirect_url}/?error={error_message}"
return RedirectResponse(url=redirect_url, headers=response.headers)
response = RedirectResponse(url=redirect_url, headers=response.headers)
return response
class OAuthManager:
@ -649,8 +655,10 @@ class OAuthManager:
return refreshed_token
else:
log.warning(
f"Token refresh failed for user {user_id}, provider {session.provider}"
f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
)
OAuthSessions.delete_session_by_id(session.id)
return None
return session.token

View File

@ -1,4 +1,4 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import type { Banner } from '$lib/types';
export const importConfig = async (token: string, config) => {
@ -208,10 +208,15 @@ type RegisterOAuthClientForm = {
client_name?: string;
};
export const registerOAuthClient = async (token: string, formData: RegisterOAuthClientForm) => {
export const registerOAuthClient = async (
token: string,
formData: RegisterOAuthClientForm,
type: null | string = null
) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register`, {
const searchParams = type ? `?type=${type}` : '';
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@ -238,6 +243,11 @@ export const registerOAuthClient = async (token: string, formData: RegisterOAuth
return res;
};
export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => {
const oauthClientId = type ? `${type}:${clientId}` : clientId;
return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`;
};
export const getCodeExecutionConfig = async (token: string) => {
let error = null;

View File

@ -57,16 +57,26 @@
return;
}
const res = await registerOAuthClient(localStorage.token, {
url: url,
client_id: id
}).catch((err) => {
const res = await registerOAuthClient(
localStorage.token,
{
url: url,
client_id: id
},
'mcp'
).catch((err) => {
toast.error($i18n.t('Registration failed'));
return null;
});
if (res) {
toast.warning(
$i18n.t(
'Please save the connection to persist the OAuth client information and do not change the ID'
)
);
toast.success($i18n.t('Registration successful'));
console.debug('Registration successful', res);
oauthClientInfo = res?.oauth_client_info ?? null;
}

View File

@ -20,6 +20,8 @@
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
import ValvesModal from '$lib/components/workspace/common/ValvesModal.svelte';
import { getOAuthClientAuthorizationUrl } from '$lib/apis/configs';
import { partition } from 'd3-hierarchy';
const i18n = getContext('i18n');
@ -321,11 +323,25 @@
{#each Object.keys(tools) as toolId}
<button
class="flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
on:click={() => {
tools[toolId].enabled = !tools[toolId].enabled;
class="relative flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
on:click={(e) => {
if (!(tools[toolId]?.authenticated ?? true)) {
e.preventDefault();
let parts = toolId.split(':');
let serverId = parts?.at(-1) ?? toolId;
const authUrl = getOAuthClientAuthorizationUrl(serverId, 'mcp');
window.open(authUrl, '_blank', 'noopener');
} else {
tools[toolId].enabled = !tools[toolId].enabled;
}
}}
>
{#if !(tools[toolId]?.authenticated ?? true)}
<!-- make it slighly darker and not clickable -->
<div class="absolute inset-0 opacity-50 rounded-xl cursor-not-allowed z-10" />
{/if}
<div class="flex-1 truncate">
<div class="flex flex-1 gap-2 items-center">
<Tooltip content={tools[toolId]?.name ?? ''} placement="top">

View File

@ -1,5 +1,15 @@
<script lang="ts">
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
import Chat from '$lib/components/chat/Chat.svelte';
import { page } from '$app/stores';
onMount(() => {
if ($page.url.searchParams.get('error')) {
toast.error($page.url.searchParams.get('error') || 'An unknown error occurred.');
}
});
</script>
<Chat />