refac
This commit is contained in:
parent
c55afc4255
commit
61f20acf61
|
@ -133,7 +133,7 @@ async def verify_tool_servers_config(
|
||||||
try:
|
try:
|
||||||
if form_data.type == "mcp":
|
if form_data.type == "mcp":
|
||||||
try:
|
try:
|
||||||
async with MCPClient() as client:
|
client = MCPClient()
|
||||||
auth = None
|
auth = None
|
||||||
headers = None
|
headers = None
|
||||||
|
|
||||||
|
@ -145,10 +145,12 @@ async def verify_tool_servers_config(
|
||||||
elif form_data.auth_type == "system_oauth":
|
elif form_data.auth_type == "system_oauth":
|
||||||
try:
|
try:
|
||||||
if request.cookies.get("oauth_session_id", None):
|
if request.cookies.get("oauth_session_id", None):
|
||||||
token = await request.app.state.oauth_manager.get_oauth_token(
|
token = (
|
||||||
|
await request.app.state.oauth_manager.get_oauth_token(
|
||||||
user.id,
|
user.id,
|
||||||
request.cookies.get("oauth_session_id", None),
|
request.cookies.get("oauth_session_id", None),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -166,6 +168,9 @@ async def verify_tool_servers_config(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Failed to create MCP client: {str(e)}",
|
detail=f"Failed to create MCP client: {str(e)}",
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
if client:
|
||||||
|
await client.disconnect()
|
||||||
else: # openapi
|
else: # openapi
|
||||||
token = None
|
token = None
|
||||||
if form_data.auth_type == "bearer":
|
if form_data.auth_type == "bearer":
|
||||||
|
|
|
@ -16,19 +16,25 @@ class MCPClient:
|
||||||
async def connect(
|
async def connect(
|
||||||
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
|
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
|
||||||
):
|
):
|
||||||
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
|
try:
|
||||||
read_stream, write_stream, _ = (
|
self._streams_context = streamablehttp_client(
|
||||||
await self._streams_context.__aenter__()
|
url, headers=headers, auth=auth
|
||||||
) # pylint: disable=E1101
|
)
|
||||||
|
|
||||||
|
transport = await self.exit_stack.enter_async_context(self._streams_context)
|
||||||
|
read_stream, write_stream, _ = transport
|
||||||
|
|
||||||
self._session_context = ClientSession(
|
self._session_context = ClientSession(
|
||||||
read_stream, write_stream
|
read_stream, write_stream
|
||||||
) # pylint: disable=W0201
|
) # pylint: disable=W0201
|
||||||
self.session: ClientSession = (
|
|
||||||
await self._session_context.__aenter__()
|
|
||||||
) # pylint: disable=C2801
|
|
||||||
|
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
self._session_context
|
||||||
|
)
|
||||||
await self.session.initialize()
|
await self.session.initialize()
|
||||||
|
except Exception as e:
|
||||||
|
await self.disconnect()
|
||||||
|
raise e
|
||||||
|
|
||||||
async def list_tool_specs(self) -> Optional[dict]:
|
async def list_tool_specs(self) -> Optional[dict]:
|
||||||
if not self.session:
|
if not self.session:
|
||||||
|
@ -97,15 +103,7 @@ class MCPClient:
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
# Clean up and close the session
|
# Clean up and close the session
|
||||||
if self.session:
|
await self.exit_stack.aclose()
|
||||||
await self._session_context.__aexit__(
|
|
||||||
None, None, None
|
|
||||||
) # pylint: disable=E1101
|
|
||||||
if self._streams_context:
|
|
||||||
await self._streams_context.__aexit__(
|
|
||||||
None, None, None
|
|
||||||
) # pylint: disable=E1101
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
await self.exit_stack.__aenter__()
|
await self.exit_stack.__aenter__()
|
||||||
|
|
Loading…
Reference in New Issue