This commit is contained in:
Timothy Jaeryang Baek 2025-09-23 03:32:25 -04:00
parent c55afc4255
commit 61f20acf61
2 changed files with 45 additions and 42 deletions

View File

@ -133,39 +133,44 @@ async def verify_tool_servers_config(
try:
if form_data.type == "mcp":
try:
async with MCPClient() as client:
auth = None
headers = None
client = MCPClient()
auth = None
headers = None
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = (
await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
pass
)
except Exception as e:
pass
if token:
headers = {"Authorization": f"Bearer {token}"}
if token:
headers = {"Authorization": f"Bearer {token}"}
await client.connect(form_data.url, auth=auth, headers=headers)
specs = await client.list_tool_specs()
return {
"status": True,
"specs": specs,
}
await client.connect(form_data.url, auth=auth, headers=headers)
specs = await client.list_tool_specs()
return {
"status": True,
"specs": specs,
}
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to create MCP client: {str(e)}",
)
finally:
if client:
await client.disconnect()
else: # openapi
token = None
if form_data.auth_type == "bearer":

View File

@ -16,19 +16,25 @@ class MCPClient:
async def connect(
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
):
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
read_stream, write_stream, _ = (
await self._streams_context.__aenter__()
) # pylint: disable=E1101
try:
self._streams_context = streamablehttp_client(
url, headers=headers, auth=auth
)
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session: ClientSession = (
await self._session_context.__aenter__()
) # pylint: disable=C2801
transport = await self.exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport
await self.session.initialize()
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session = await self.exit_stack.enter_async_context(
self._session_context
)
await self.session.initialize()
except Exception as e:
await self.disconnect()
raise e
async def list_tool_specs(self) -> Optional[dict]:
if not self.session:
@ -97,15 +103,7 @@ class MCPClient:
async def disconnect(self):
# Clean up and close the session
if self.session:
await self._session_context.__aexit__(
None, None, None
) # pylint: disable=E1101
if self._streams_context:
await self._streams_context.__aexit__(
None, None, None
) # pylint: disable=E1101
self.session = None
await self.exit_stack.aclose()
async def __aenter__(self):
await self.exit_stack.__aenter__()