This commit is contained in:
Timothy Jaeryang Baek 2025-09-28 12:42:02 -05:00
parent 3aad157006
commit 742e2ff193
2 changed files with 14 additions and 11 deletions

View File

@ -1552,7 +1552,7 @@ async def chat_completion(
finally: finally:
try: try:
if mcp_clients := metadata.get("mcp_clients"): if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients: for client in mcp_clients.values():
await client.disconnect() await client.disconnect()
except Exception as e: except Exception as e:
log.debug(f"Error cleaning up: {e}") log.debug(f"Error cleaning up: {e}")

View File

@ -1096,7 +1096,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tools_dict = {} tools_dict = {}
mcp_clients = [] mcp_clients = {}
mcp_tools_dict = {} mcp_tools_dict = {}
if tool_ids: if tool_ids:
@ -1157,25 +1157,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.error(f"Error getting OAuth token: {e}") log.error(f"Error getting OAuth token: {e}")
oauth_token = None oauth_token = None
mcp_client = MCPClient() mcp_clients[server_id] = MCPClient()
await mcp_client.connect( await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""), url=mcp_server_connection.get("url", ""),
headers=headers if headers else None, headers=headers if headers else None,
) )
tool_specs = await mcp_client.list_tool_specs() tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs: for tool_spec in tool_specs:
def make_tool_function(function_name): def make_tool_function(client, function_name):
async def tool_function(**kwargs): async def tool_function(**kwargs):
return await mcp_client.call_tool( print(kwargs)
print(client)
print(await client.list_tool_specs())
return await client.call_tool(
function_name, function_name,
function_args=kwargs, function_args=kwargs,
) )
return tool_function return tool_function
tool_function = make_tool_function(tool_spec["name"]) tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"]
)
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = { mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": { "spec": {
@ -1184,11 +1189,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}, },
"callable": tool_function, "callable": tool_function,
"type": "mcp", "type": "mcp",
"client": mcp_client, "client": mcp_clients[server_id],
"direct": False, "direct": False,
} }
mcp_clients.append(mcp_client)
except Exception as e: except Exception as e:
log.debug(e) log.debug(e)
continue continue