
Commit ee28ad8

improve event store example implementation
1 parent 5757f6c commit ee28ad8

File tree

2 files changed: +115 −83 lines changed

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py

Lines changed: 68 additions & 43 deletions
@@ -6,75 +6,100 @@
 """
 
 import logging
-import time
-from collections.abc import Awaitable, Callable
-from operator import itemgetter
+from collections import deque
+from dataclasses import dataclass
 from uuid import uuid4
 
-from mcp.server.streamable_http import EventId, EventStore, StreamId
+from mcp.server.streamable_http import (
+    EventCallback,
+    EventId,
+    EventMessage,
+    EventStore,
+    StreamId,
+)
 from mcp.types import JSONRPCMessage
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class EventEntry:
+    """
+    Represents an event entry in the event store.
+    """
+
+    event_id: EventId
+    stream_id: StreamId
+    message: JSONRPCMessage
+
+
 class InMemoryEventStore(EventStore):
     """
     Simple in-memory implementation of the EventStore interface for resumability.
     This is primarily intended for examples and testing, not for production use
     where a persistent storage solution would be more appropriate.
+
+    This implementation keeps only the last N events per stream for memory efficiency.
     """
 
-    def __init__(self):
-        self.events: dict[
-            str, tuple[str, JSONRPCMessage, float]
-        ] = {}  # event_id -> (stream_id, message, timestamp)
+    def __init__(self, max_events_per_stream: int = 100):
+        """Initialize the event store.
+
+        Args:
+            max_events_per_stream: Maximum number of events to keep per stream
+        """
+        self.max_events_per_stream = max_events_per_stream
+        # for maintaining last N events per stream
+        self.streams: dict[StreamId, deque[EventEntry]] = {}
+        # event_id -> EventEntry for quick lookup
+        self.event_index: dict[EventId, EventEntry] = {}
 
     async def store_event(
         self, stream_id: StreamId, message: JSONRPCMessage
     ) -> EventId:
        """Stores an event with a generated event ID."""
        event_id = str(uuid4())
-        self.events[event_id] = (stream_id, message, time.time())
+        event_entry = EventEntry(
+            event_id=event_id, stream_id=stream_id, message=message
+        )
+
+        # Get or create deque for this stream
+        if stream_id not in self.streams:
+            self.streams[stream_id] = deque(maxlen=self.max_events_per_stream)
+
+        # If deque is full, the oldest event will be automatically removed
+        # We need to remove it from the event_index as well
+        if len(self.streams[stream_id]) == self.max_events_per_stream:
+            oldest_event = self.streams[stream_id][0]
+            self.event_index.pop(oldest_event.event_id, None)
+
+        # Add new event
+        self.streams[stream_id].append(event_entry)
+        self.event_index[event_id] = event_entry
 
        return event_id
 
    async def replay_events_after(
        self,
        last_event_id: EventId,
-        send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]],
-    ) -> StreamId:
+        send_callback: EventCallback,
+    ) -> StreamId | None:
        """Replays events that occurred after the specified event ID."""
-        logger.debug(f"Attempting to replay events after {last_event_id}")
-        logger.debug(f"Total events in store: {len(self.events)}")
-        logger.debug(f"Event IDs in store: {list(self.events.keys())}")
-
-        if not last_event_id or last_event_id not in self.events:
+        if last_event_id not in self.event_index:
            logger.warning(f"Event ID {last_event_id} not found in store")
-            return ""
-
-        # Get the stream ID and timestamp from the last event
-        stream_id, _, last_timestamp = self.events[last_event_id]
-
-        # Find all events for this stream after the last event
-        events_sorted = sorted(
-            [
-                (event_id, message, timestamp)
-                for event_id, (sid, message, timestamp) in self.events.items()
-                if sid == stream_id and timestamp > last_timestamp
-            ],
-            key=itemgetter(2),
-        )
-
-        events_to_replay = [
-            (event_id, message) for event_id, message, _ in events_sorted
-        ]
-
-        logger.debug(f"Found {len(events_to_replay)} events to replay")
-        logger.debug(
-            f"Events to replay: {[event_id for event_id, _ in events_to_replay]}"
-        )
-
-        # Send all events in order
-        for event_id, message in events_to_replay:
-            await send_callback(event_id, message)
+            return None
+
+        # Get the stream and find events after the last one
+        last_event = self.event_index[last_event_id]
+        stream_id = last_event.stream_id
+        stream_events = self.streams.get(last_event.stream_id, deque())
+
+        # Events in deque are already in chronological order
+        found_last = False
+        for event in stream_events:
+            if found_last:
+                await send_callback(EventMessage(event.message, event.event_id))
+            elif event.event_id == last_event_id:
+                found_last = True
 
        return stream_id
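The rewritten store replaces timestamp sorting with a bounded deque per stream plus an event-ID index. Below is a minimal, self-contained sketch of that pattern — illustrative only: plain strings stand in for JSONRPCMessage, and BoundedEventLog/Entry are hypothetical names, not part of the commit.

import asyncio
from collections import deque
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class Entry:
    event_id: str
    stream_id: str
    message: str  # stand-in for JSONRPCMessage


class BoundedEventLog:
    """Hypothetical miniature of InMemoryEventStore's new data layout."""

    def __init__(self, max_events_per_stream: int = 3):
        self.max_events_per_stream = max_events_per_stream
        self.streams: dict[str, deque[Entry]] = {}
        self.index: dict[str, Entry] = {}

    def store(self, stream_id: str, message: str) -> str:
        event_id = str(uuid4())
        stream = self.streams.setdefault(
            stream_id, deque(maxlen=self.max_events_per_stream)
        )
        # When the deque is full, appending silently drops the oldest entry,
        # so drop it from the index first to keep the two in sync.
        if len(stream) == self.max_events_per_stream:
            self.index.pop(stream[0].event_id, None)
        entry = Entry(event_id, stream_id, message)
        stream.append(entry)
        self.index[event_id] = entry
        return event_id

    async def replay_after(self, last_event_id: str, send) -> str | None:
        entry = self.index.get(last_event_id)
        if entry is None:
            return None  # unknown or already-evicted event ID
        # The deque is already in insertion order; send everything after
        # the matching event.
        found = False
        for e in self.streams.get(entry.stream_id, deque()):
            if found:
                await send(e.event_id, e.message)
            elif e.event_id == last_event_id:
                found = True
        return entry.stream_id


async def main() -> None:
    log = BoundedEventLog(max_events_per_stream=3)
    ids = [log.store("s1", f"msg-{i}") for i in range(5)]

    async def send(event_id: str, message: str) -> None:
        print(event_id, message)

    # msg-0 and msg-1 were evicted; replaying after msg-2 prints msg-3, msg-4.
    await log.replay_after(ids[2], send)


asyncio.run(main())

The deque gives O(1) eviction of the oldest event, and the side index keeps lookup by event ID O(1), which is what removes the need for the old sorted() pass over every stored event.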

src/mcp/server/streamable_http.py

Lines changed: 47 additions & 40 deletions
@@ -63,6 +63,22 @@
 EventId = str
 
 
+class EventMessage:
+    """
+    A JSONRPCMessage with an optional event ID for stream resumability.
+    """
+
+    message: JSONRPCMessage
+    event_id: str | None
+
+    def __init__(self, message: JSONRPCMessage, event_id: str | None = None):
+        self.message = message
+        self.event_id = event_id
+
+
+EventCallback = Callable[[EventMessage], Awaitable[None]]
+
+
 class EventStore(ABC):
     """
     Interface for resumability support via event storage.
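For illustration only (not part of the commit): any coroutine matching the new EventCallback signature can serve as a send_callback. A minimal sketch, with print_event as a hypothetical name and print standing in for the real SSE writer; it mirrors the send_event callback this commit defines in replay_sender further below.

from mcp.server.streamable_http import EventMessage


async def print_event(event_message: EventMessage) -> None:
    # Build an SSE-style event dict from the wrapped message.
    event_data = {
        "event": "message",
        "data": event_message.message.model_dump_json(
            by_alias=True, exclude_none=True
        ),
    }
    # Include the event ID when the event store assigned one.
    if event_message.event_id:
        event_data["id"] = event_message.event_id
    print(event_data)

Because the callback now takes one EventMessage instead of an (EventId, JSONRPCMessage) pair, the event ID travels inside the wrapper rather than as a separate argument.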
@@ -88,8 +104,8 @@ async def store_event(
     async def replay_events_after(
         self,
         last_event_id: EventId,
-        send_callback: Callable[[EventId, JSONRPCMessage], Awaitable[None]],
-    ) -> StreamId:
+        send_callback: EventCallback,
+    ) -> StreamId | None:
         """
         Replays events that occurred after the specified event ID.
 
@@ -149,7 +165,7 @@ def __init__(
         self.is_json_response_enabled = is_json_response_enabled
         self._event_store = event_store
         self._request_streams: dict[
-            RequestId, MemoryObjectSendStream[tuple[JSONRPCMessage, str | None]]
+            RequestId, MemoryObjectSendStream[EventMessage]
         ] = {}
         self._terminated = False
 
@@ -358,7 +374,7 @@ async def _handle_post_request(
             request_id = str(message.root.id)
             # Create promise stream for getting response
             request_stream_writer, request_stream_reader = (
-                anyio.create_memory_object_stream[tuple[JSONRPCMessage, str | None]](0)
+                anyio.create_memory_object_stream[EventMessage](0)
             )
 
             # Register this stream for the request ID
@@ -373,16 +389,18 @@ async def _handle_post_request(
                 response_message = None
 
                 # Use similar approach to SSE writer for consistency
-                async for received_message, _ in request_stream_reader:
+                async for event_message in request_stream_reader:
                     # If it's a response, this is what we're waiting for
                     if isinstance(
-                        received_message.root, JSONRPCResponse | JSONRPCError
+                        event_message.message.root, JSONRPCResponse | JSONRPCError
                     ):
-                        response_message = received_message
+                        response_message = event_message.message
                         break
                     # For notifications and request, keep waiting
                     else:
-                        logger.debug(f"received: {received_message.root.method}")
+                        logger.debug(
+                            f"received: {event_message.message.root.method}"
+                        )
 
                 # At this point we should have a response
                 if response_message:
@@ -424,27 +442,24 @@ async def sse_writer():
             try:
                 async with sse_stream_writer, request_stream_reader:
                     # Process messages from the request-specific stream
-                    async for (
-                        received_message,
-                        event_id,
-                    ) in request_stream_reader:
+                    async for event_message in request_stream_reader:
                         # Build the event data
                         event_data = {
                             "event": "message",
-                            "data": received_message.model_dump_json(
+                            "data": event_message.message.model_dump_json(
                                 by_alias=True, exclude_none=True
                             ),
                         }
 
                         # If an event ID was provided, include it
-                        if event_id:
-                            event_data["id"] = event_id
+                        if event_message.event_id:
+                            event_data["id"] = event_message.event_id
 
                         await sse_stream_writer.send(event_data)
 
                         # If response, remove from pending streams and close
                         if isinstance(
-                            received_message.root,
+                            event_message.message.root,
                             JSONRPCResponse | JSONRPCError,
                         ):
                             if request_id:
@@ -563,20 +578,15 @@ async def standalone_sse_writer():
             try:
                 # Create a standalone message stream for server-initiated messages
                 standalone_stream_writer, standalone_stream_reader = (
-                    anyio.create_memory_object_stream[
-                        tuple[JSONRPCMessage, str | None]
-                    ](0)
+                    anyio.create_memory_object_stream[EventMessage](0)
                 )
 
                 # Register this stream using the special key
                 self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
 
                 async with sse_stream_writer, standalone_stream_reader:
                     # Process messages from the standalone stream
-                    async for item in standalone_stream_reader:
-                        # The message router always sends a tuple of (message, event_id)
-                        received_message, event_id = item
-
+                    async for event_message in standalone_stream_reader:
                         # For the standalone stream, we handle:
                         # - JSONRPCNotification (server sends notifications to client)
                         # - JSONRPCRequest (server sends requests to client)
@@ -585,14 +595,14 @@ async def standalone_sse_writer():
                         # Send the message via SSE
                         event_data = {
                             "event": "message",
-                            "data": received_message.model_dump_json(
+                            "data": event_message.message.model_dump_json(
                                 by_alias=True, exclude_none=True
                             ),
                         }
 
                         # If an event ID was provided, include it in the SSE stream
-                        if event_id:
-                            event_data["id"] = event_id
+                        if event_message.event_id:
+                            event_data["id"] = event_message.event_id
 
                         await sse_stream_writer.send(event_data)
             except Exception as e:
@@ -741,14 +751,12 @@ async def replay_sender():
             try:
                 async with sse_stream_writer:
                     # Define an async callback for sending events
-                    async def send_event(
-                        event_id: EventId, message: JSONRPCMessage
-                    ) -> None:
+                    async def send_event(event_message: EventMessage) -> None:
                         await sse_stream_writer.send(
                             {
                                 "event": "message",
-                                "id": event_id,
-                                "data": message.model_dump_json(
+                                "id": event_message.event_id,
+                                "data": event_message.message.model_dump_json(
                                     by_alias=True, exclude_none=True
                                 ),
                             }
@@ -762,22 +770,21 @@ async def send_event(
                     # If stream ID not in mapping, create it
                     if stream_id and stream_id not in self._request_streams:
                         msg_writer, msg_reader = anyio.create_memory_object_stream[
-                            tuple[JSONRPCMessage, str | None]
+                            EventMessage
                         ](0)
                         self._request_streams[stream_id] = msg_writer
 
                         # Forward messages to SSE
                         async with msg_reader:
-                            async for item in msg_reader:
-                                message, event_id = item
-
+                            async for event_message in msg_reader:
+                                event_data = event_message.message.model_dump_json(
+                                    by_alias=True, exclude_none=True
+                                )
                                 await sse_stream_writer.send(
                                     {
                                         "event": "message",
-                                        "id": event_id,
-                                        "data": message.model_dump_json(
-                                            by_alias=True, exclude_none=True
-                                        ),
+                                        "id": event_message.event_id,
+                                        "data": event_data,
                                     }
                                 )
             except Exception as e:
@@ -871,7 +878,7 @@ async def message_router():
             try:
                 # Send both the message and the event ID
                 await self._request_streams[request_stream_id].send(
-                    (message, event_id)
+                    EventMessage(message, event_id)
                 )
             except (
                 anyio.BrokenResourceError,
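Stepping back, the refactor threads EventMessage through every memory object stream in place of (message, event_id) tuples. A rough, self-contained sketch of that plumbing follows — illustrative only: EventMsg and the print-based writer are stand-ins, not commit code, and the subscripted anyio.create_memory_object_stream form matches what the diff itself uses.

from dataclasses import dataclass

import anyio


@dataclass
class EventMsg:
    """Stand-in for the EventMessage class added in this commit."""

    message: str  # stand-in for JSONRPCMessage
    event_id: str | None = None


async def main() -> None:
    # Zero-buffer stream, as in the diff: each send blocks until received.
    send_stream, receive_stream = anyio.create_memory_object_stream[EventMsg](0)

    async def router() -> None:
        # Mirrors message_router: wrap the message and event ID together.
        async with send_stream:
            for i in range(3):
                await send_stream.send(EventMsg(f"msg-{i}", event_id=str(i)))

    async def writer() -> None:
        # Mirrors sse_writer: unwrap and build the SSE event dict.
        async with receive_stream:
            async for event_message in receive_stream:
                event_data = {"event": "message", "data": event_message.message}
                if event_message.event_id:
                    event_data["id"] = event_message.event_id
                print(event_data)

    async with anyio.create_task_group() as tg:
        tg.start_soon(router)
        tg.start_soon(writer)


anyio.run(main)

Bundling the ID with the message in one object keeps the stream element type simple and lets readers that don't care about resumability ignore event_id instead of destructuring a tuple.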
