Skip to content

Commit 5dbddeb

Browse files
committed
refactor _create_event_data
1 parent ee28ad8 commit 5dbddeb

File tree

1 file changed

+28
-50
lines changed

1 file changed

+28
-50
lines changed

src/mcp/server/streamable_http.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from abc import ABC, abstractmethod
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from contextlib import asynccontextmanager
16+
from dataclasses import dataclass
1617
from http import HTTPStatus
17-
from typing import Any
1818

1919
import anyio
2020
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -63,17 +63,14 @@
6363
EventId = str
6464

6565

66+
@dataclass
6667
class EventMessage:
6768
"""
6869
A JSONRPCMessage with an optional event ID for stream resumability.
6970
"""
7071

7172
message: JSONRPCMessage
72-
event_id: str | None
73-
74-
def __init__(self, message: JSONRPCMessage, event_id: str | None = None):
75-
self.message = message
76-
self.event_id = event_id
73+
event_id: str | None = None
7774

7875

7976
EventCallback = Callable[[EventMessage], Awaitable[None]]
@@ -226,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None:
226223
"""Extract the session ID from request headers."""
227224
return request.headers.get(MCP_SESSION_ID_HEADER)
228225

226+
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
227+
"""Create event data dictionary from an EventMessage."""
228+
event_data = {
229+
"event": "message",
230+
"data": event_message.message.model_dump_json(
231+
by_alias=True, exclude_none=True
232+
),
233+
}
234+
235+
# If an event ID was provided, include it
236+
if event_message.event_id:
237+
event_data["id"] = event_message.event_id
238+
239+
return event_data
240+
229241
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
230242
"""Application entry point that handles all HTTP requests"""
231243
request = Request(scope, receive)
@@ -434,7 +446,7 @@ async def _handle_post_request(
434446
else:
435447
# Create SSE stream
436448
sse_stream_writer, sse_stream_reader = (
437-
anyio.create_memory_object_stream[dict[str, Any]](0)
449+
anyio.create_memory_object_stream[dict[str, str]](0)
438450
)
439451

440452
async def sse_writer():
@@ -444,17 +456,7 @@ async def sse_writer():
444456
# Process messages from the request-specific stream
445457
async for event_message in request_stream_reader:
446458
# Build the event data
447-
event_data = {
448-
"event": "message",
449-
"data": event_message.message.model_dump_json(
450-
by_alias=True, exclude_none=True
451-
),
452-
}
453-
454-
# If an event ID was provided, include it
455-
if event_message.event_id:
456-
event_data["id"] = event_message.event_id
457-
459+
event_data = self._create_event_data(event_message)
458460
await sse_stream_writer.send(event_data)
459461

460462
# If response, remove from pending streams and close
@@ -571,7 +573,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
571573

572574
# Create SSE stream
573575
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
574-
dict[str, Any]
576+
dict[str, str]
575577
](0)
576578

577579
async def standalone_sse_writer():
@@ -593,17 +595,7 @@ async def standalone_sse_writer():
593595
# We should NOT receive JSONRPCResponse
594596

595597
# Send the message via SSE
596-
event_data = {
597-
"event": "message",
598-
"data": event_message.message.model_dump_json(
599-
by_alias=True, exclude_none=True
600-
),
601-
}
602-
603-
# If an event ID was provided, include it in the SSE stream
604-
if event_message.event_id:
605-
event_data["id"] = event_message.event_id
606-
598+
event_data = self._create_event_data(event_message)
607599
await sse_stream_writer.send(event_data)
608600
except Exception as e:
609601
logger.exception(f"Error in standalone SSE writer: {e}")
@@ -744,23 +736,16 @@ async def _replay_events(
744736

745737
# Create SSE stream for replay
746738
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
747-
dict[str, Any]
739+
dict[str, str]
748740
](0)
749741

750742
async def replay_sender():
751743
try:
752744
async with sse_stream_writer:
753745
# Define an async callback for sending events
754746
async def send_event(event_message: EventMessage) -> None:
755-
await sse_stream_writer.send(
756-
{
757-
"event": "message",
758-
"id": event_message.event_id,
759-
"data": event_message.message.model_dump_json(
760-
by_alias=True, exclude_none=True
761-
),
762-
}
763-
)
747+
event_data = self._create_event_data(event_message)
748+
await sse_stream_writer.send(event_data)
764749

765750
# Replay past events and get the stream ID
766751
stream_id = await event_store.replay_events_after(
@@ -777,16 +762,9 @@ async def send_event(event_message: EventMessage) -> None:
777762
# Forward messages to SSE
778763
async with msg_reader:
779764
async for event_message in msg_reader:
780-
event_data = event_message.message.model_dump_json(
781-
by_alias=True, exclude_none=True
782-
)
783-
await sse_stream_writer.send(
784-
{
785-
"event": "message",
786-
"id": event_message.event_id,
787-
"data": event_data,
788-
}
789-
)
765+
event_data = self._create_event_data(event_message)
766+
767+
await sse_stream_writer.send(event_data)
790768
except Exception as e:
791769
logger.exception(f"Error in replay sender: {e}")
792770

0 commit comments

Comments
 (0)