refac
This commit is contained in:
parent
3aad157006
commit
742e2ff193
|
@ -1552,7 +1552,7 @@ async def chat_completion(
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
if mcp_clients := metadata.get("mcp_clients"):
|
if mcp_clients := metadata.get("mcp_clients"):
|
||||||
for client in mcp_clients:
|
for client in mcp_clients.values():
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Error cleaning up: {e}")
|
log.debug(f"Error cleaning up: {e}")
|
||||||
|
|
|
@ -1096,7 +1096,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
|
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
||||||
mcp_clients = []
|
mcp_clients = {}
|
||||||
mcp_tools_dict = {}
|
mcp_tools_dict = {}
|
||||||
|
|
||||||
if tool_ids:
|
if tool_ids:
|
||||||
|
@ -1157,25 +1157,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
oauth_token = None
|
oauth_token = None
|
||||||
|
|
||||||
mcp_client = MCPClient()
|
mcp_clients[server_id] = MCPClient()
|
||||||
await mcp_client.connect(
|
await mcp_clients[server_id].connect(
|
||||||
url=mcp_server_connection.get("url", ""),
|
url=mcp_server_connection.get("url", ""),
|
||||||
headers=headers if headers else None,
|
headers=headers if headers else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_specs = await mcp_client.list_tool_specs()
|
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||||
for tool_spec in tool_specs:
|
for tool_spec in tool_specs:
|
||||||
|
|
||||||
def make_tool_function(function_name):
|
def make_tool_function(client, function_name):
|
||||||
async def tool_function(**kwargs):
|
async def tool_function(**kwargs):
|
||||||
return await mcp_client.call_tool(
|
print(kwargs)
|
||||||
|
print(client)
|
||||||
|
print(await client.list_tool_specs())
|
||||||
|
return await client.call_tool(
|
||||||
function_name,
|
function_name,
|
||||||
function_args=kwargs,
|
function_args=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_function
|
return tool_function
|
||||||
|
|
||||||
tool_function = make_tool_function(tool_spec["name"])
|
tool_function = make_tool_function(
|
||||||
|
mcp_clients[server_id], tool_spec["name"]
|
||||||
|
)
|
||||||
|
|
||||||
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
||||||
"spec": {
|
"spec": {
|
||||||
|
@ -1184,11 +1189,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
},
|
},
|
||||||
"callable": tool_function,
|
"callable": tool_function,
|
||||||
"type": "mcp",
|
"type": "mcp",
|
||||||
"client": mcp_client,
|
"client": mcp_clients[server_id],
|
||||||
"direct": False,
|
"direct": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
mcp_clients.append(mcp_client)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(e)
|
log.debug(e)
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue