Spaces:
Sleeping
Sleeping
| # utils/mcp_client.py | |
| import asyncio | |
| import os | |
| from typing import Dict, Any, List, Optional | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp.client.stdio import stdio_client | |
class MCPMux:
    """Route tool calls to whichever connected MCP stdio server owns the tool.

    Each server is registered under a caller-chosen name. The tools it
    advertises are indexed by tool name so ``call`` can find the right
    session. If two servers advertise the same tool name, the server
    connected last wins in the index.
    """

    def __init__(self):
        # name -> {"session": ClientSession, "tools": {tool_name: tool}}
        self._servers = {}
        # tool_name -> name of the server that provides it
        self._tool_index = {}
        # name -> entered stdio_client context manager; kept so the transport
        # stays open and can be exited in close()
        self._streams = {}

    async def connect_stdio(
        self,
        name: str,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
    ):
        """Start an MCP server as a child process and register its tools.

        Args:
            name: Key under which the server is registered. Reusing a name
                replaces the earlier registration without closing it.
            command: Executable used to launch the server.
            args: Command-line arguments for the server (default: none).
            env: Extra environment variables layered over the current
                process environment.

        Returns:
            The list of tool objects reported by the server.

        Raises:
            Whatever the transport or session raises while starting up; any
            context manager that was already entered is exited first.
        """
        # Inherit current env so secrets (e.g., ANTHROPIC_API_KEY) reach
        # child servers; explicit entries override inherited ones.
        inherited_env = os.environ.copy()
        if env:
            inherited_env.update(env)

        params = StdioServerParameters(
            command=command, args=args or [], env=inherited_env
        )
        stdio = stdio_client(params)
        read, write = await stdio.__aenter__()

        session = None
        session_entered = False
        try:
            session = ClientSession(read, write)
            await session.__aenter__()
            session_entered = True
            await session.initialize()
            tools_result = await session.list_tools()
            tools = tools_result.tools
        except BaseException:
            # Startup failed part-way: unwind what was entered (session
            # first, then its transport) so nothing leaks, then re-raise
            # the original error.
            try:
                if session_entered:
                    await session.__aexit__(None, None, None)
            finally:
                await stdio.__aexit__(None, None, None)
            raise

        # Register only after a fully successful handshake.
        self._streams[name] = stdio
        self._servers[name] = {
            "session": session,
            "tools": {t.name: t for t in tools},
        }
        for t in tools:
            self._tool_index[t.name] = name
        return tools

    async def call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Invoke ``tool_name`` on the server that registered it.

        Args:
            tool_name: Name of a tool returned by a prior ``connect_stdio``.
            arguments: Arguments forwarded to the tool.

        Returns:
            The text parts of the response joined with newlines. Items with
            no non-empty ``text`` but a ``data`` attribute contribute
            ``str(data)``. Returns ``"(no content)"`` if nothing usable
            came back.

        Raises:
            ValueError: If no connected server provides ``tool_name``.
        """
        if tool_name not in self._tool_index:
            raise ValueError(f"Unknown tool: {tool_name}")

        server_name = self._tool_index[tool_name]
        session = self._servers[server_name]["session"]

        print(f"π§ MCP call: {server_name}.{tool_name}({arguments})")
        res = await session.call_tool(tool_name, arguments=arguments)
        print(f"β MCP response type: {type(res)}, hasattr content: {hasattr(res, 'content')}")

        parts = []
        # A response without ``content`` (or with None) is treated as empty
        # instead of raising AttributeError/TypeError.
        for c in getattr(res, "content", None) or ():
            if hasattr(c, "text") and c.text:
                parts.append(c.text)
            elif hasattr(c, "data"):
                parts.append(str(c.data))

        result = "\n".join(parts) if parts else "(no content)"
        print(f"π€ MCP result: {result[:200]}...")
        return result

    async def list_all_tools(self):
        """Return ``{server_name: [tool_name, ...]}`` for every server."""
        return {
            name: list(server["tools"].keys())
            for name, server in self._servers.items()
        }

    async def close(self):
        """Shut down every session and its stdio transport.

        Servers are torn down in reverse connection order, and each session
        is exited before the transport it runs on. Teardown is best-effort:
        a failure on one server does not stop the others from closing. The
        first error seen is re-raised once all state has been cleared.
        Calling ``close`` again afterwards is a no-op.
        """
        names = list(self._streams)
        names.extend(n for n in self._servers if n not in self._streams)

        first_error = None
        for name in reversed(names):
            entry = self._servers.pop(name, None)
            stream = self._streams.pop(name, None)

            managers = []
            if entry is not None:
                managers.append(entry["session"])
            if stream is not None:
                managers.append(stream)

            for manager in managers:
                try:
                    await manager.__aexit__(None, None, None)
                except Exception as exc:
                    if first_error is None:
                        first_error = exc

        # No session is left, so no tool can be routed any more.
        self._tool_index.clear()

        if first_error is not None:
            raise first_error