Skip to content

Commit bef6a26

Browse files
vblagojedavidsbatistaAmnah199
authored
feat: Add Secret support to MCPTool and MCPToolset (#1900)
* Add Secret support to MCPTool and MCPToolset * Minor fixes * Minor fix * Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * Post merge fixes --------- Co-authored-by: David S. Batista <dsbatista@gmail.com> Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
1 parent ec040be commit bef6a26

File tree

2 files changed

+275
-24
lines changed

2 files changed

+275
-24
lines changed

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
2020
from haystack.tools import Tool
2121
from haystack.tools.errors import ToolInvocationError
22+
from haystack.utils import Secret, deserialize_secrets_inplace
23+
from haystack.utils.auth import SecretType
2224
from haystack.utils.url_validation import is_valid_http_url
2325

2426
from mcp import ClientSession, StdioServerParameters, types
@@ -356,7 +358,7 @@ class StdioClient(MCPClient):
356358
MCP client that connects to servers using stdio transport.
357359
"""
358360

359-
def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None) -> None:
361+
def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str | Secret] | None = None) -> None:
360362
"""
361363
Initialize a stdio MCP client.
362364
@@ -367,7 +369,12 @@ def __init__(self, command: str, args: list[str] | None = None, env: dict[str, s
367369
super().__init__()
368370
self.command: str = command
369371
self.args: list[str] = args or []
370-
self.env: dict[str, str] | None = env
372+
# Resolve Secret values in environment variables
373+
self.env: dict[str, str] | None = None
374+
if env:
375+
self.env = {
376+
key: value.resolve_value() if isinstance(value, Secret) else value for key, value in env.items()
377+
}
371378
logger.debug(f"PROCESS: Created StdioClient for command: {command} {' '.join(self.args or [])}")
372379

373380
async def connect(self) -> list[Tool]:
@@ -400,7 +407,9 @@ def __init__(self, server_info: "SSEServerInfo") -> None:
400407
# in post_init we validate the url and set the url field so it is guaranteed to be valid
401408
# safely ignore the mypy warning here
402409
self.url: str = server_info.url # type: ignore[assignment]
403-
self.token: str | None = server_info.token
410+
self.token: str | None = (
411+
server_info.token.resolve_value() if isinstance(server_info.token, Secret) else server_info.token
412+
)
404413
self.timeout: int = server_info.timeout
405414

406415
async def connect(self) -> list[Tool]:
@@ -431,7 +440,9 @@ def __init__(self, server_info: "StreamableHttpServerInfo") -> None:
431440
super().__init__()
432441

433442
self.url: str = server_info.url
434-
self.token: str | None = server_info.token
443+
self.token: str | None = (
444+
server_info.token.resolve_value() if isinstance(server_info.token, Secret) else server_info.token
445+
)
435446
self.timeout: int = server_info.timeout
436447

437448
async def connect(self) -> list[Tool]:
@@ -441,6 +452,13 @@ async def connect(self) -> list[Tool]:
441452
:returns: List of available tools on the server
442453
:raises MCPConnectionError: If connection to the server fails
443454
"""
455+
if streamablehttp_client is None:
456+
message = (
457+
"Streamable HTTP client is not available. "
458+
"This may require a newer version of the mcp package that includes mcp.client.streamable_http"
459+
)
460+
raise MCPConnectionError(message=message, operation="streamable_http_connect")
461+
444462
headers = {"Authorization": f"Bearer {self.token}"} if self.token else None
445463
streamablehttp_transport = await self.exit_stack.enter_async_context(
446464
streamablehttp_client(url=self.url, headers=headers, timeout=timedelta(seconds=self.timeout))
@@ -475,8 +493,16 @@ def to_dict(self) -> dict[str, Any]:
475493
result = {"type": generate_qualified_class_name(type(self))}
476494

477495
# Add all fields from the dataclass
478-
for field in fields(self):
479-
result[field.name] = getattr(self, field.name)
496+
for dataclass_field in fields(self):
497+
value = getattr(self, dataclass_field.name)
498+
if hasattr(value, "to_dict"):
499+
result[dataclass_field.name] = value.to_dict()
500+
elif isinstance(value, dict):
501+
result[dataclass_field.name] = {
502+
k: v.to_dict() if hasattr(v, "to_dict") else v for k, v in value.items()
503+
}
504+
else:
505+
result[dataclass_field.name] = value
480506

481507
return result
482508

@@ -492,6 +518,26 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPServerInfo":
492518
data_copy = data.copy()
493519
data_copy.pop("type", None)
494520

521+
secret_types = {e.value for e in SecretType}
522+
field_names = {f.name for f in fields(cls)}
523+
524+
# Iterate over a static list of items to avoid mutation issues
525+
for name, value in list(data_copy.items()):
526+
if name not in field_names or not isinstance(value, dict):
527+
continue
528+
529+
# Top-level secret?
530+
if value.get("type") in secret_types:
531+
deserialize_secrets_inplace(data_copy, keys=[name])
532+
continue
533+
534+
# Nested secrets (one level deep)
535+
nested_keys: list[str] = [
536+
k for k, v in value.items() if isinstance(v, dict) and v.get("type") in secret_types
537+
]
538+
if nested_keys:
539+
deserialize_secrets_inplace(value, keys=nested_keys)
540+
495541
# Create an instance of the class with the remaining fields
496542
return cls(**data_copy)
497543

@@ -501,6 +547,16 @@ class SSEServerInfo(MCPServerInfo):
501547
"""
502548
Data class that encapsulates SSE MCP server connection parameters.
503549
550+
For authentication tokens containing sensitive data, you can use Secret objects
551+
for secure handling and serialization:
552+
553+
```python
554+
server_info = SSEServerInfo(
555+
url="https://my-mcp-server.com",
556+
token=Secret.from_env_var("API_KEY"),
557+
)
558+
```
559+
504560
:param url: Full URL of the MCP server (including /sse endpoint)
505561
:param base_url: Base URL of the MCP server (deprecated, use url instead)
506562
:param token: Authentication token for the server (optional)
@@ -509,7 +565,7 @@ class SSEServerInfo(MCPServerInfo):
509565

510566
url: str | None = None
511567
base_url: str | None = None # deprecated
512-
token: str | None = None
568+
token: str | Secret | None = None
513569
timeout: int = 30
514570

515571
def __post_init__(self):
@@ -553,13 +609,23 @@ class StreamableHttpServerInfo(MCPServerInfo):
553609
"""
554610
Data class that encapsulates streamable HTTP MCP server connection parameters.
555611
612+
For authentication tokens containing sensitive data, you can use Secret objects
613+
for secure handling and serialization:
614+
615+
```python
616+
server_info = StreamableHttpServerInfo(
617+
url="https://my-mcp-server.com",
618+
token=Secret.from_env_var("API_KEY"),
619+
)
620+
```
621+
556622
:param url: Full URL of the MCP server (streamable HTTP endpoint)
557623
:param token: Authentication token for the server (optional)
558624
:param timeout: Connection timeout in seconds
559625
"""
560626

561627
url: str
562-
token: str | None = None
628+
token: str | Secret | None = None
563629
timeout: int = 30
564630

565631
def __post_init__(self):
@@ -585,11 +651,29 @@ class StdioServerInfo(MCPServerInfo):
585651
:param command: Command to run (e.g., "python", "node")
586652
:param args: Arguments to pass to the command
587653
:param env: Environment variables for the command
654+
655+
For environment variables containing sensitive data, you can use Secret objects
656+
for secure handling and serialization:
657+
658+
```python
659+
server_info = StdioServerInfo(
660+
command="uv",
661+
args=["run", "my-mcp-server"],
662+
env={
663+
"WORKSPACE_PATH": "/path/to/workspace", # Plain string
664+
"API_KEY": Secret.from_env_var("API_KEY"), # Secret object
665+
}
666+
)
667+
```
668+
669+
Secret objects will be properly serialized and deserialized without exposing
670+
the secret value, while plain strings will be preserved as-is. Use Secret objects
671+
for sensitive data that needs to be handled securely.
588672
"""
589673

590674
command: str
591675
args: list[str] | None = None
592-
env: dict[str, str] | None = None
676+
env: dict[str, str | Secret] | None = None
593677

594678
def create_client(self) -> MCPClient:
595679
"""

0 commit comments

Comments
 (0)