From 742e2ff193a153bbb65f909d1296bbcf31f062fa Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 28 Sep 2025 12:42:02 -0500 Subject: [PATCH] refac --- backend/open_webui/main.py | 2 +- backend/open_webui/utils/middleware.py | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index f38bd47109..61a1639ee3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1552,7 +1552,7 @@ async def chat_completion( finally: try: if mcp_clients := metadata.get("mcp_clients"): - for client in mcp_clients: + for client in mcp_clients.values(): await client.disconnect() except Exception as e: log.debug(f"Error cleaning up: {e}") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 5a2b46d3f6..0a55ae76f0 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1096,7 +1096,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): tools_dict = {} - mcp_clients = [] + mcp_clients = {} mcp_tools_dict = {} if tool_ids: @@ -1157,25 +1157,30 @@ async def process_chat_payload(request, form_data, user, metadata, model): log.error(f"Error getting OAuth token: {e}") oauth_token = None - mcp_client = MCPClient() - await mcp_client.connect( + mcp_clients[server_id] = MCPClient() + await mcp_clients[server_id].connect( url=mcp_server_connection.get("url", ""), headers=headers if headers else None, ) - tool_specs = await mcp_client.list_tool_specs() + tool_specs = await mcp_clients[server_id].list_tool_specs() for tool_spec in tool_specs: - def make_tool_function(function_name): + def make_tool_function(client, function_name): async def tool_function(**kwargs): - return await mcp_client.call_tool( + print(kwargs) + print(client) + print(await client.list_tool_specs()) + return await client.call_tool( function_name, function_args=kwargs, ) return tool_function - tool_function = make_tool_function(tool_spec["name"]) + tool_function = make_tool_function( + mcp_clients[server_id], tool_spec["name"] + ) mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = { "spec": { @@ -1184,11 +1189,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): }, "callable": tool_function, "type": "mcp", - "client": mcp_client, + "client": mcp_clients[server_id], "direct": False, } - - mcp_clients.append(mcp_client) except Exception as e: log.debug(e) continue