Skip to content

Commit e2fa0d5

Browse files
committed
Add webhook capability and implement in streamable http
1 parent 2ca2de7 commit e2fa0d5

File tree

7 files changed

+226
-11
lines changed

7 files changed

+226
-11
lines changed

src/mcp/client/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ async def call_tool(
271271
arguments: dict[str, Any] | None = None,
272272
read_timeout_seconds: timedelta | None = None,
273273
progress_callback: ProgressFnT | None = None,
274+
webhooks: list[types.Webhook] | None = None,
274275
) -> types.CallToolResult:
275276
"""Send a tools/call request with optional progress callback support."""
276277

@@ -282,6 +283,7 @@ async def call_tool(
282283
name=name,
283284
arguments=arguments,
284285
),
286+
webhooks=webhooks,
285287
)
286288
),
287289
types.CallToolResult,

src/mcp/server/fastmcp/server.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from mcp.types import Resource as MCPResource
6464
from mcp.types import ResourceTemplate as MCPResourceTemplate
6565
from mcp.types import Tool as MCPTool
66+
from mcp.types import Webhook
6667

6768
logger = get_logger(__name__)
6869

@@ -99,6 +100,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
99100
stateless_http: bool = (
100101
False # If True, uses true stateless mode (new transport per request)
101102
)
103+
webhooks_supported: bool = False
102104

103105
# resource settings
104106
warn_on_duplicate_resources: bool = True
@@ -150,11 +152,10 @@ def __init__(
150152
self._mcp_server = MCPServer(
151153
name=name or "FastMCP",
152154
instructions=instructions,
153-
lifespan=(
154-
lifespan_wrapper(self, self.settings.lifespan)
155-
if self.settings.lifespan
156-
else default_lifespan
157-
),
155+
lifespan=lifespan_wrapper(self, self.settings.lifespan)
156+
if self.settings.lifespan
157+
else default_lifespan,
158+
webhooks_supported=self.settings.webhooks_supported,
158159
)
159160
self._tool_manager = ToolManager(
160161
tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
@@ -165,6 +166,7 @@ def __init__(
165166
self._prompt_manager = PromptManager(
166167
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
167168
)
169+
168170
if (self.settings.auth is not None) != (auth_server_provider is not None):
169171
# TODO: after we support separate authorization servers (see
170172
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
@@ -272,11 +274,19 @@ def get_context(self) -> Context[ServerSession, object]:
272274
return Context(request_context=request_context, fastmcp=self)
273275

274276
async def call_tool(
275-
self, name: str, arguments: dict[str, Any]
277+
self,
278+
name: str,
279+
arguments: dict[str, Any],
280+
webhooks: list[Webhook] | None = None,
276281
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
277282
"""Call a tool by name with arguments."""
278283
context = self.get_context()
279-
result = await self._tool_manager.call_tool(name, arguments, context=context)
284+
result = await self._tool_manager.call_tool(
285+
name,
286+
arguments,
287+
context=context,
288+
webhooks=webhooks
289+
)
280290
converted_result = _convert_to_content(result)
281291
return converted_result
282292

@@ -777,6 +787,7 @@ def streamable_http_app(self) -> Starlette:
777787
event_store=self._event_store,
778788
json_response=self.settings.json_response,
779789
stateless=self.settings.stateless_http, # Use the stateless setting
790+
webhooks_supported=self.settings.webhooks_supported, # Use the webhooks supported setting
780791
)
781792

782793
# Create the ASGI handler
@@ -929,6 +940,7 @@ def my_tool(x: int, ctx: Context) -> str:
929940

930941
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
931942
_fastmcp: FastMCP | None
943+
has_webhook: bool = False
932944

933945
def __init__(
934946
self,

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mcp.server.fastmcp.tools.base import Tool
88
from mcp.server.fastmcp.utilities.logging import get_logger
99
from mcp.shared.context import LifespanContextT
10-
from mcp.types import ToolAnnotations
10+
from mcp.types import ToolAnnotations, Webhook
1111

1212
if TYPE_CHECKING:
1313
from mcp.server.fastmcp.server import Context
@@ -66,10 +66,16 @@ async def call_tool(
6666
name: str,
6767
arguments: dict[str, Any],
6868
context: Context[ServerSessionT, LifespanContextT] | None = None,
69+
webhooks: list[Webhook] | None = None,
6970
) -> Any:
7071
"""Call a tool by name with arguments."""
7172
tool = self.get_tool(name)
7273
if not tool:
7374
raise ToolError(f"Unknown tool: {name}")
7475

76+
if context:
77+
context.has_webhook = (
78+
webhooks is not None and len(webhooks) > 0
79+
)
80+
7581
return await tool.run(arguments, context=context)

src/mcp/server/lowlevel/server.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
lifespan: Callable[
133133
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
134134
] = lifespan,
135+
webhooks_supported: bool = False,
135136
):
136137
self.name = name
137138
self.version = version
@@ -144,6 +145,7 @@ def __init__(
144145
}
145146
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
146147
self.notification_options = NotificationOptions()
148+
self.webhooks_supported = webhooks_supported
147149
logger.debug(f"Initializing server '{name}'")
148150

149151
def create_initialization_options(
@@ -199,7 +201,8 @@ def get_capabilities(
199201
# Set tool capabilities if handler exists
200202
if types.ListToolsRequest in self.request_handlers:
201203
tools_capability = types.ToolsCapability(
202-
listChanged=notification_options.tools_changed
204+
listChanged=notification_options.tools_changed,
205+
webhooksSupported=self.webhooks_supported,
203206
)
204207

205208
# Set logging capabilities if handler exists
@@ -409,7 +412,7 @@ def decorator(
409412

410413
async def handler(req: types.CallToolRequest):
411414
try:
412-
results = await func(req.params.name, (req.params.arguments or {}))
415+
results = await func(req.params.name, (req.params.arguments or {}), req.webhooks)
413416
return types.ServerResult(
414417
types.CallToolResult(content=list(results), isError=False)
415418
)

src/mcp/server/streamable_http.py

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
responses, with streaming support for long-running operations.
88
"""
99

10+
import asyncio
1011
import json
1112
import logging
1213
import re
@@ -24,6 +25,7 @@
2425
from starlette.responses import Response
2526
from starlette.types import Receive, Scope, Send
2627

28+
from mcp.shared._httpx_utils import create_mcp_http_client
2729
from mcp.shared.message import ServerMessageMetadata, SessionMessage
2830
from mcp.types import (
2931
INTERNAL_ERROR,
@@ -36,6 +38,7 @@
3638
JSONRPCRequest,
3739
JSONRPCResponse,
3840
RequestId,
41+
Webhook,
3942
)
4043

4144
logger = logging.getLogger(__name__)
@@ -136,6 +139,7 @@ def __init__(
136139
self,
137140
mcp_session_id: str | None,
138141
is_json_response_enabled: bool = False,
142+
is_webhooks_supported: bool = False,
139143
event_store: EventStore | None = None,
140144
) -> None:
141145
"""
@@ -146,6 +150,10 @@ def __init__(
146150
Must contain only visible ASCII characters (0x21-0x7E).
147151
is_json_response_enabled: If True, return JSON responses for requests
148152
instead of SSE streams. Default is False.
153+
is_webhooks_supported: If True and if webhooks are provided in
154+
tools/call request, the client will receive an Accepted
155+
HTTP response and the CallTool response will be sent to
156+
the webhook. Default is False.
149157
event_store: Event store for resumability support. If provided,
150158
resumability will be enabled, allowing clients to
151159
reconnect and resume messages.
@@ -162,6 +170,7 @@ def __init__(
162170

163171
self.mcp_session_id = mcp_session_id
164172
self.is_json_response_enabled = is_json_response_enabled
173+
self.is_webhooks_supported = is_webhooks_supported
165174
self._event_store = event_store
166175
self._request_streams: dict[
167176
RequestId,
@@ -410,9 +419,43 @@ async def _handle_post_request(
410419
](0)
411420
request_stream_reader = self._request_streams[request_id][1]
412421

422+
session_message = SessionMessage(message)
423+
if self._is_call_tool_request_with_webhooks(
424+
session_message.message
425+
):
426+
if self.is_webhooks_supported:
427+
response = self._create_json_response(
428+
JSONRPCMessage(root=JSONRPCResponse(
429+
jsonrpc="2.0",
430+
id=message.root.id,
431+
result={
432+
'content': [{
433+
'type': 'text',
434+
'text': 'Response will be forwarded to the webhook.'
435+
}],
436+
'isError': False
437+
},
438+
)),
439+
HTTPStatus.OK,
440+
)
441+
asyncio.create_task(
442+
self._send_response_to_webhooks(
443+
request_id, session_message, request_stream_reader
444+
)
445+
)
446+
else:
447+
logger.exception("Webhooks not supported error")
448+
err = "Webhooks not supported"
449+
response = self._create_error_response(
450+
f"Validation error: {err}",
451+
HTTPStatus.BAD_REQUEST,
452+
INVALID_PARAMS,
453+
)
454+
await response(scope, receive, send)
455+
return
456+
413457
if self.is_json_response_enabled:
414458
# Process the message
415-
session_message = SessionMessage(message)
416459
await writer.send(session_message)
417460
try:
418461
# Process messages from the request-specific stream
@@ -531,6 +574,115 @@ async def sse_writer():
531574
await writer.send(Exception(err))
532575
return
533576

577+
578+
async def _send_response_to_webhooks(
579+
self,
580+
request_id: str,
581+
session_message: SessionMessage,
582+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
583+
):
584+
webhooks: list[Webhook] = [Webhook(**webhook) for webhook in session_message.message.root.webhooks]
585+
writer = self._read_stream_writer
586+
if writer is None:
587+
raise ValueError(
588+
"No read stream writer available. Ensure connect() is called first."
589+
)
590+
await writer.send(session_message)
591+
592+
try:
593+
response_message = JSONRPCError(
594+
jsonrpc="2.0",
595+
id="server-error", # We don't have a request ID for general errors
596+
error=ErrorData(
597+
code=INTERNAL_ERROR,
598+
message="Error processing request: No response received",
599+
),
600+
)
601+
602+
if self.is_json_response_enabled:
603+
# Process messages from the request-specific stream
604+
# We need to collect all messages until we get a response
605+
async for event_message in request_stream_reader:
606+
# If it's a response, this is what we're waiting for
607+
if isinstance(
608+
event_message.message.root, JSONRPCResponse | JSONRPCError
609+
):
610+
response_message = event_message.message
611+
break
612+
# For notifications and request, keep waiting
613+
else:
614+
logger.debug(
615+
f"received: {event_message.message.root.method}"
616+
)
617+
618+
await self._send_message_to_webhooks(webhooks, response_message)
619+
else:
620+
# Send each event on the request stream as a separate message
621+
async for event_message in request_stream_reader:
622+
event_data = self._create_event_data(event_message)
623+
await self._send_message_to_webhooks(webhooks, event_data)
624+
625+
# If response, remove from pending streams and close
626+
if isinstance(
627+
event_message.message.root,
628+
JSONRPCResponse | JSONRPCError,
629+
):
630+
break
631+
632+
except Exception as e:
633+
logger.exception(f"Error sending response to webhooks: {e}")
634+
635+
finally:
636+
await self._clean_up_memory_streams(request_id)
637+
638+
639+
async def _send_message_to_webhooks(
640+
self,
641+
webhooks: list[Webhook],
642+
message: JSONRPCMessage | JSONRPCError | dict[str, str],
643+
):
644+
for webhook in webhooks:
645+
headers = {"Content-Type": CONTENT_TYPE_JSON}
646+
# Add authorization headers
647+
if webhook.authentication and webhook.authentication.credentials:
648+
if webhook.authentication.strategy == "bearer":
649+
headers["Authorization"] = f"Bearer {webhook.authentication.credentials}"
650+
elif webhook.authentication.strategy == "apiKey":
651+
headers["X-API-Key"] = webhook.authentication.credentials
652+
elif webhook.authentication.strategy == "basic":
653+
try:
654+
# Try to parse as JSON
655+
creds_dict = json.loads(webhook.authentication.credentials)
656+
if "username" in creds_dict and "password" in creds_dict:
657+
# Create basic auth header from username and password
658+
import base64
659+
auth_string = f"{creds_dict['username']}:{creds_dict['password']}"
660+
credentials = base64.b64encode(auth_string.encode()).decode()
661+
headers["Authorization"] = f"Basic {credentials}"
662+
except:
663+
# Not JSON, use as-is
664+
headers["Authorization"] = f"Basic {webhook.authentication.credentials}"
665+
elif webhook.authentication.strategy == "customHeader" and webhook.authentication.credentials:
666+
try:
667+
custom_headers = json.loads(webhook.authentication.credentials)
668+
headers.update(custom_headers)
669+
except:
670+
pass
671+
672+
async with create_mcp_http_client(headers=headers) as client:
673+
try:
674+
if isinstance(message, JSONRPCMessage | JSONRPCError):
675+
await client.post(
676+
webhook.url,
677+
json=message.model_dump_json(by_alias=True, exclude_none=True),
678+
)
679+
else:
680+
await client.post(webhook.url, json=message)
681+
682+
except Exception as e:
683+
logger.exception(f"Error sending response to webhook {webhook.url}: {e}")
684+
685+
534686
async def _handle_get_request(self, request: Request, send: Send) -> None:
535687
"""
536688
Handle GET request to establish SSE.
@@ -651,6 +803,18 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
651803
)
652804
await response(request.scope, request.receive, send)
653805

806+
807+
def _is_call_tool_request_with_webhooks(self, message: JSONRPCMessage) -> bool:
808+
"""Check if the request is a call tool request with webhooks."""
809+
return (
810+
isinstance(message.root, JSONRPCRequest)
811+
and message.root.method == "tools/call"
812+
and hasattr(message.root, "webhooks")
813+
and message.root.webhooks is not None
814+
and len(message.root.webhooks) > 0
815+
)
816+
817+
654818
async def _terminate_session(self) -> None:
655819
"""Terminate the current session, closing all streams.
656820

0 commit comments

Comments
 (0)