Skip to content

Add message queue for SSE messages POST endpoint #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
10c5af8
initial
akash329d Apr 8, 2025
c3d5efc
readme update
akash329d Apr 8, 2025
b2fce7d
ruff
akash329d Apr 8, 2025
7c82f36
fix typing issues
akash329d Apr 9, 2025
b92f22f
update lock
akash329d Apr 9, 2025
5dbca6e
retrigger tests?
akash329d Apr 9, 2025
badc1e2
revert
akash329d Apr 9, 2025
23665db
clean up test stuff
akash329d Apr 9, 2025
ccd5a13
lock pydantic version
akash329d Apr 9, 2025
fb44020
fix lock
akash329d Apr 9, 2025
efe6da9
wip
akash329d Apr 14, 2025
d625782
fixes
akash329d Apr 14, 2025
78c6aef
Add optional redis dep
akash329d Apr 14, 2025
fad836c
changes
akash329d Apr 14, 2025
fd97501
format / lint
akash329d Apr 14, 2025
4bce7d8
cleanup
akash329d Apr 14, 2025
d6075bb
update lock
akash329d Apr 14, 2025
8ee3a7e
remove redundant comment
akash329d Apr 14, 2025
7cabcea
add a checkpoint
akash329d Apr 14, 2025
5111c92
naming changes
akash329d Apr 15, 2025
09e0cab
logging improvements
akash329d Apr 15, 2025
8d280d8
better channel validation
akash329d Apr 15, 2025
c2bb049
merge
akash329d Apr 15, 2025
87e07b8
formatting and linting
akash329d Apr 15, 2025
b484284
fix naming in server.py
akash329d Apr 15, 2025
0bfd800
Rework to fix POST blocking issue
akash329d Apr 21, 2025
1e81f36
comments fix
akash329d Apr 21, 2025
215cc42
wip
akash329d Apr 22, 2025
8fce8e6
back to b48428486aa90f7529c36e5a78074ac2a2d813bc
akash329d Apr 22, 2025
b2893e6
push message handling onto corresponding SSE session task group
akash329d Apr 22, 2025
e5938d4
format
akash329d Apr 22, 2025
a151f1c
clean up comment and session state
akash329d Apr 22, 2025
d22f46b
shorten comment
akash329d Apr 22, 2025
8d6a20d
remove extra change
akash329d Apr 23, 2025
bb24881
testing
akash329d Apr 24, 2025
564561f
add a cancelscope on the finally
akash329d May 1, 2025
9419ad0
Move to session heartbeat w/ TTL
akash329d May 1, 2025
046ed94
add test for TTL
akash329d May 1, 2025
70547c0
merge conflict
akash329d May 5, 2025
5638653
merge fixes
akash329d May 5, 2025
2437e46
fakeredis dev dep
akash329d May 5, 2025
9664c8a
fmt
akash329d May 5, 2025
30b475b
convert to Pydantic models
akash329d May 5, 2025
0114189
fmt
akash329d May 5, 2025
7081a40
more type fixes
akash329d May 5, 2025
5ae3cc6
test cleanup
akash329d May 5, 2025
46b78f2
rename to message dispatch
akash329d May 5, 2025
e21d514
make int tests better
akash329d May 6, 2025
ee9f4de
lint
akash329d May 6, 2025
206a98a
tests hanging
akash329d May 6, 2025
bb59e5d
do cleanup after test
akash329d May 6, 2025
ca9a54a
fmt
akash329d May 6, 2025
9832c34
clean up int test
akash329d May 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app()))

For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes).

#### Message Dispatch Options

By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol:

```python
# Using the built-in Redis message dispatch
from mcp.server.fastmcp import FastMCP
from mcp.server.message_queue import RedisMessageDispatch

# Create a Redis message dispatch
redis_dispatch = RedisMessageDispatch(
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
)

# Pass the message dispatch instance to the server
mcp = FastMCP("My App", message_queue=redis_dispatch)
```

To use Redis, add the Redis dependency:

```bash
uv add "mcp[redis]"
```

## Examples

### Echo Server
Expand Down
5 changes: 4 additions & 1 deletion examples/servers/simple-prompt/mcp_simple_prompt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ async def get_prompt(
)

if transport == "sse":
from mcp.server.message_queue.redis import RedisMessageDispatch
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route

sse = SseServerTransport("/messages/")
message_dispatch = RedisMessageDispatch("redis://localhost:6379/0")

sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)

async def handle_sse(request):
async with sse.connect_sse(
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
rich = ["rich>=13.9.4"]
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
ws = ["websockets>=15.0.1"]
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]

[project.scripts]
mcp = "mcp.cli:app [cli]"
Expand All @@ -55,6 +56,7 @@ dev = [
"pytest-xdist>=3.6.1",
"pytest-examples>=0.0.14",
"pytest-pretty>=1.2.0",
"fakeredis==2.28.1",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ async def sse_reader(
await read_stream_writer.send(exc)
continue

session_message = SessionMessage(message)
session_message = SessionMessage(
message=message
)
await read_stream_writer.send(session_message)
case _:
logger.warning(
Expand Down Expand Up @@ -148,3 +150,5 @@ async def post_writer(endpoint_url: str):
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
2 changes: 1 addition & 1 deletion src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def stdout_reader():
await read_stream_writer.send(exc)
continue

session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def _handle_sse_event(
):
message.root.id = original_request_id

session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await read_stream_writer.send(session_message)

# Call resumption token callback if we have an ID
Expand Down Expand Up @@ -286,7 +286,7 @@ async def _handle_json_response(
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
Expand Down Expand Up @@ -333,7 +333,7 @@ async def _send_session_terminated_error(
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await read_stream_writer.send(session_message)

async def post_writer(
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def ws_reader():
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
session_message = SessionMessage(message)
session_message = SessionMessage(message=message)
await read_stream_writer.send(session_message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
Expand Down
31 changes: 28 additions & 3 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.lowlevel.server import Server as MCPServer
from mcp.server.lowlevel.server import lifespan as default_lifespan
from mcp.server.message_queue import MessageDispatch
from mcp.server.session import ServerSession, ServerSessionT
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
Expand Down Expand Up @@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
sse_path: str = "/sse"
message_path: str = "/messages/"

# SSE message queue settings
message_dispatch: MessageDispatch | None = Field(
None, description="Custom message dispatch instance"
)

# resource settings
warn_on_duplicate_resources: bool = True

Expand Down Expand Up @@ -569,12 +575,21 @@ async def run_sse_async(self) -> None:

def sse_app(self) -> Starlette:
"""Return an instance of the SSE server app."""
message_dispatch = self.settings.message_dispatch
if message_dispatch is None:
from mcp.server.message_queue import InMemoryMessageDispatch

message_dispatch = InMemoryMessageDispatch()
logger.info("Using default in-memory message dispatch")

from starlette.middleware import Middleware
from starlette.routing import Mount, Route

# Set up auth context and dependencies

sse = SseServerTransport(self.settings.message_path)
sse = SseServerTransport(
self.settings.message_path, message_dispatch=message_dispatch
)

async def handle_sse(scope: Scope, receive: Receive, send: Send):
# Add client ID from auth context into request context if available
Expand All @@ -589,7 +604,14 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
streams[1],
self._mcp_server.create_initialization_options(),
)
return Response()
return Response()

@asynccontextmanager
async def lifespan(app: Starlette):
try:
yield
finally:
await message_dispatch.close()

# Create routes
routes: list[Route | Mount] = []
Expand Down Expand Up @@ -666,7 +688,10 @@ async def sse_endpoint(request: Request) -> None:

# Create Starlette app with routes and middleware
return Starlette(
debug=self.settings.debug, routes=routes, middleware=middleware
debug=self.settings.debug,
routes=routes,
middleware=middleware,
lifespan=lifespan,
)

async def list_prompts(self) -> list[MCPPrompt]:
Expand Down
16 changes: 16 additions & 0 deletions src/mcp/server/message_queue/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Message Dispatch Module for MCP Server

This module implements dispatch interfaces for handling
messages between clients and servers.
"""

from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch

# Try to import Redis implementation if available
try:
from mcp.server.message_queue.redis import RedisMessageDispatch
except ImportError:
RedisMessageDispatch = None

__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]
116 changes: 116 additions & 0 deletions src/mcp/server/message_queue/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Protocol, runtime_checkable
from uuid import UUID

from pydantic import ValidationError

from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]]


@runtime_checkable
class MessageDispatch(Protocol):
"""Abstract interface for SSE message dispatching.

This interface allows messages to be published to sessions and callbacks to be
registered for message handling, enabling multiple servers to handle requests.
"""

async def publish_message(
self, session_id: UUID, message: SessionMessage | str
) -> bool:
"""Publish a message for the specified session.

Args:
session_id: The UUID of the session this message is for
message: The message to publish (SessionMessage or str for invalid JSON)

Returns:
bool: True if message was published, False if session not found
"""
...

@asynccontextmanager
async def subscribe(self, session_id: UUID, callback: MessageCallback):
"""Request-scoped context manager that subscribes to messages for a session.

Args:
session_id: The UUID of the session to subscribe to
callback: Async callback function to handle messages for this session
"""
yield

async def session_exists(self, session_id: UUID) -> bool:
"""Check if a session exists.

Args:
session_id: The UUID of the session to check

Returns:
bool: True if the session is active, False otherwise
"""
...

async def close(self) -> None:
"""Close the message dispatch."""
...


class InMemoryMessageDispatch:
"""Default in-memory implementation of the MessageDispatch interface.

This implementation immediately dispatches messages to registered callbacks when
messages are received without any queuing behavior.
"""

def __init__(self) -> None:
self._callbacks: dict[UUID, MessageCallback] = {}

async def publish_message(
self, session_id: UUID, message: SessionMessage | str
) -> bool:
"""Publish a message for the specified session."""
if session_id not in self._callbacks:
logger.warning(f"Message dropped: unknown session {session_id}")
return False

# Parse string messages or recreate original ValidationError
if isinstance(message, str):
try:
callback_argument = SessionMessage.model_validate_json(message)
except ValidationError as exc:
callback_argument = exc
else:
callback_argument = message

# Call the callback with either valid message or recreated ValidationError
await self._callbacks[session_id](callback_argument)

logger.debug(f"Message dispatched to session {session_id}")
return True

@asynccontextmanager
async def subscribe(self, session_id: UUID, callback: MessageCallback):
"""Request-scoped context manager that subscribes to messages for a session."""
self._callbacks[session_id] = callback
logger.debug(f"Subscribing to messages for session {session_id}")

try:
yield
finally:
if session_id in self._callbacks:
del self._callbacks[session_id]
logger.debug(f"Unsubscribed from session {session_id}")

async def session_exists(self, session_id: UUID) -> bool:
"""Check if a session exists."""
return session_id in self._callbacks

async def close(self) -> None:
"""Close the message dispatch."""
pass
Loading
Loading