refac
This commit is contained in:
parent
c55afc4255
commit
61f20acf61
|
@ -133,39 +133,44 @@ async def verify_tool_servers_config(
|
|||
try:
|
||||
if form_data.type == "mcp":
|
||||
try:
|
||||
async with MCPClient() as client:
|
||||
auth = None
|
||||
headers = None
|
||||
client = MCPClient()
|
||||
auth = None
|
||||
headers = None
|
||||
|
||||
token = None
|
||||
if form_data.auth_type == "bearer":
|
||||
token = form_data.key
|
||||
elif form_data.auth_type == "session":
|
||||
token = request.state.token.credentials
|
||||
elif form_data.auth_type == "system_oauth":
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
token = None
|
||||
if form_data.auth_type == "bearer":
|
||||
token = form_data.key
|
||||
elif form_data.auth_type == "session":
|
||||
token = request.state.token.credentials
|
||||
elif form_data.auth_type == "system_oauth":
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
token = (
|
||||
await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if token:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
if token:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
await client.connect(form_data.url, auth=auth, headers=headers)
|
||||
specs = await client.list_tool_specs()
|
||||
return {
|
||||
"status": True,
|
||||
"specs": specs,
|
||||
}
|
||||
await client.connect(form_data.url, auth=auth, headers=headers)
|
||||
specs = await client.list_tool_specs()
|
||||
return {
|
||||
"status": True,
|
||||
"specs": specs,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to create MCP client: {str(e)}",
|
||||
)
|
||||
finally:
|
||||
if client:
|
||||
await client.disconnect()
|
||||
else: # openapi
|
||||
token = None
|
||||
if form_data.auth_type == "bearer":
|
||||
|
|
|
@ -16,19 +16,25 @@ class MCPClient:
|
|||
async def connect(
|
||||
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
|
||||
):
|
||||
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
|
||||
read_stream, write_stream, _ = (
|
||||
await self._streams_context.__aenter__()
|
||||
) # pylint: disable=E1101
|
||||
try:
|
||||
self._streams_context = streamablehttp_client(
|
||||
url, headers=headers, auth=auth
|
||||
)
|
||||
|
||||
self._session_context = ClientSession(
|
||||
read_stream, write_stream
|
||||
) # pylint: disable=W0201
|
||||
self.session: ClientSession = (
|
||||
await self._session_context.__aenter__()
|
||||
) # pylint: disable=C2801
|
||||
transport = await self.exit_stack.enter_async_context(self._streams_context)
|
||||
read_stream, write_stream, _ = transport
|
||||
|
||||
await self.session.initialize()
|
||||
self._session_context = ClientSession(
|
||||
read_stream, write_stream
|
||||
) # pylint: disable=W0201
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
self._session_context
|
||||
)
|
||||
await self.session.initialize()
|
||||
except Exception as e:
|
||||
await self.disconnect()
|
||||
raise e
|
||||
|
||||
async def list_tool_specs(self) -> Optional[dict]:
|
||||
if not self.session:
|
||||
|
@ -97,15 +103,7 @@ class MCPClient:
|
|||
|
||||
async def disconnect(self):
|
||||
# Clean up and close the session
|
||||
if self.session:
|
||||
await self._session_context.__aexit__(
|
||||
None, None, None
|
||||
) # pylint: disable=E1101
|
||||
if self._streams_context:
|
||||
await self._streams_context.__aexit__(
|
||||
None, None, None
|
||||
) # pylint: disable=E1101
|
||||
self.session = None
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.exit_stack.__aenter__()
|
||||
|
|
Loading…
Reference in New Issue