ghostMalone / utils /mcp_client.py
francischung222's picture
MCP bug fix for Hugging Face secret
b0dd11f
raw
history blame
2.75 kB
# utils/mcp_client.py
import asyncio
import os
from contextlib import AsyncExitStack
from typing import Dict, Any, List, Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class MCPMux:
    """Route tool calls to whichever connected MCP server exposes the tool.

    Each server is launched as a child process and spoken to over stdio.
    ``connect_stdio`` registers a server and indexes its tools by name,
    ``call`` dispatches a tool invocation to the owning server, and
    ``close`` tears every connection down.
    """

    def __init__(self):
        # server name -> {"session": ClientSession, "tools": {tool name: tool}}
        self._servers = {}
        # tool name -> name of the server that provides it (last connect wins)
        self._tool_index = {}
        # server name -> stdio_client context manager. Kept for reference only;
        # its lifetime is owned by self._exit_stack.
        self._streams = {}
        # Owns every opened transport/session so close() can unwind them in
        # reverse order of entry, mirroring how nested ``async with`` blocks
        # would exit.
        self._exit_stack = AsyncExitStack()

    async def connect_stdio(
        self,
        name: str,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
    ) -> List[Any]:
        """Spawn an MCP server over stdio, initialize a session, and index its tools.

        Args:
            name: Key under which the server is registered. Reusing a name
                replaces the earlier registration; the earlier connection
                stays open until ``close()``.
            command: Executable that starts the server.
            args: Command-line arguments for ``command`` (default: none).
            env: Extra environment variables, layered over a copy of the
                current process environment.

        Returns:
            The tool objects reported by the server's ``list_tools``.

        Raises:
            Whatever the transport or session raises during startup. Any
            partially opened transport/session is closed before re-raising.
        """
        # Inherit current env so secrets (e.g., ANTHROPIC_API_KEY) reach child servers
        inherited_env = os.environ.copy()
        if env:
            inherited_env.update(env)
        params = StdioServerParameters(command=command, args=args or [], env=inherited_env)
        stdio = stdio_client(params)

        # Open the transport and session under a temporary stack. If any step
        # fails, leaving the ``async with`` closes what was already opened.
        async with AsyncExitStack() as startup:
            read, write = await startup.enter_async_context(stdio)
            session = await startup.enter_async_context(ClientSession(read, write))
            await session.initialize()
            tools = (await session.list_tools()).tools
            # Startup succeeded: move the cleanup callbacks out of the
            # temporary stack so leaving the block keeps the connection open.
            server_stack = startup.pop_all()

        # Hand ownership to the mux-wide stack. close() unwinds it LIFO.
        self._exit_stack.push_async_callback(server_stack.aclose)
        self._streams[name] = stdio

        # Drop tools still routed to a previous connection under this name.
        stale = [tool for tool, owner in self._tool_index.items() if owner == name]
        for tool in stale:
            del self._tool_index[tool]

        self._servers[name] = {"session": session, "tools": {t.name: t for t in tools}}
        for t in tools:
            self._tool_index[t.name] = name
        return tools

    async def call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Invoke ``tool_name`` on the server that provides it.

        Returns:
            The text of every non-empty text part plus ``str()`` of every part
            exposing a ``data`` attribute, joined by newlines. Returns
            ``"(no content)"`` when no part contributed anything.

        Raises:
            ValueError: if no connected server exposes ``tool_name``.
        """
        if tool_name not in self._tool_index:
            raise ValueError(f"Unknown tool: {tool_name}")
        server_name = self._tool_index[tool_name]
        session = self._servers[server_name]["session"]
        print(f"🔧 MCP call: {server_name}.{tool_name}({arguments})")
        res = await session.call_tool(tool_name, arguments=arguments)
        print(f"✅ MCP response type: {type(res)}, hasattr content: {hasattr(res, 'content')}")
        parts = []
        for c in res.content:
            if hasattr(c, "text") and c.text:
                parts.append(c.text)
            elif hasattr(c, "data"):
                parts.append(str(c.data))
        result = "\n".join(parts) if parts else "(no content)"
        print(f"📤 MCP result: {result[:200]}...")
        return result

    async def list_all_tools(self) -> Dict[str, List[str]]:
        """Return each connected server's tool names, keyed by server name."""
        return {name: list(s["tools"].keys()) for name, s in self._servers.items()}

    async def close(self):
        """Close every session and stdio transport, newest first, then reset state.

        Safe to call more than once. Afterwards the mux holds no servers and
        can accept new connections.
        """
        try:
            await self._exit_stack.aclose()
        finally:
            self._exit_stack = AsyncExitStack()
            self._servers.clear()
            self._tool_index.clear()
            self._streams.clear()