13
13
from abc import ABC , abstractmethod
14
14
from collections .abc import AsyncGenerator , Awaitable , Callable
15
15
from contextlib import asynccontextmanager
16
+ from dataclasses import dataclass
16
17
from http import HTTPStatus
17
- from typing import Any
18
18
19
19
import anyio
20
20
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
63
63
EventId = str
64
64
65
65
66
+ @dataclass
66
67
class EventMessage :
67
68
"""
68
69
A JSONRPCMessage with an optional event ID for stream resumability.
69
70
"""
70
71
71
72
message : JSONRPCMessage
72
- event_id : str | None
73
-
74
- def __init__ (self , message : JSONRPCMessage , event_id : str | None = None ):
75
- self .message = message
76
- self .event_id = event_id
73
+ event_id : str | None = None
77
74
78
75
79
76
EventCallback = Callable [[EventMessage ], Awaitable [None ]]
@@ -226,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None:
226
223
"""Extract the session ID from request headers."""
227
224
return request .headers .get (MCP_SESSION_ID_HEADER )
228
225
226
+ def _create_event_data (self , event_message : EventMessage ) -> dict [str , str ]:
227
+ """Create event data dictionary from an EventMessage."""
228
+ event_data = {
229
+ "event" : "message" ,
230
+ "data" : event_message .message .model_dump_json (
231
+ by_alias = True , exclude_none = True
232
+ ),
233
+ }
234
+
235
+ # If an event ID was provided, include it
236
+ if event_message .event_id :
237
+ event_data ["id" ] = event_message .event_id
238
+
239
+ return event_data
240
+
229
241
async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
230
242
"""Application entry point that handles all HTTP requests"""
231
243
request = Request (scope , receive )
@@ -434,7 +446,7 @@ async def _handle_post_request(
434
446
else :
435
447
# Create SSE stream
436
448
sse_stream_writer , sse_stream_reader = (
437
- anyio .create_memory_object_stream [dict [str , Any ]](0 )
449
+ anyio .create_memory_object_stream [dict [str , str ]](0 )
438
450
)
439
451
440
452
async def sse_writer ():
@@ -444,17 +456,7 @@ async def sse_writer():
444
456
# Process messages from the request-specific stream
445
457
async for event_message in request_stream_reader :
446
458
# Build the event data
447
- event_data = {
448
- "event" : "message" ,
449
- "data" : event_message .message .model_dump_json (
450
- by_alias = True , exclude_none = True
451
- ),
452
- }
453
-
454
- # If an event ID was provided, include it
455
- if event_message .event_id :
456
- event_data ["id" ] = event_message .event_id
457
-
459
+ event_data = self ._create_event_data (event_message )
458
460
await sse_stream_writer .send (event_data )
459
461
460
462
# If response, remove from pending streams and close
@@ -571,7 +573,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
571
573
572
574
# Create SSE stream
573
575
sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [
574
- dict [str , Any ]
576
+ dict [str , str ]
575
577
](0 )
576
578
577
579
async def standalone_sse_writer ():
@@ -593,17 +595,7 @@ async def standalone_sse_writer():
593
595
# We should NOT receive JSONRPCResponse
594
596
595
597
# Send the message via SSE
596
- event_data = {
597
- "event" : "message" ,
598
- "data" : event_message .message .model_dump_json (
599
- by_alias = True , exclude_none = True
600
- ),
601
- }
602
-
603
- # If an event ID was provided, include it in the SSE stream
604
- if event_message .event_id :
605
- event_data ["id" ] = event_message .event_id
606
-
598
+ event_data = self ._create_event_data (event_message )
607
599
await sse_stream_writer .send (event_data )
608
600
except Exception as e :
609
601
logger .exception (f"Error in standalone SSE writer: { e } " )
@@ -744,23 +736,16 @@ async def _replay_events(
744
736
745
737
# Create SSE stream for replay
746
738
sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [
747
- dict [str , Any ]
739
+ dict [str , str ]
748
740
](0 )
749
741
750
742
async def replay_sender ():
751
743
try :
752
744
async with sse_stream_writer :
753
745
# Define an async callback for sending events
754
746
async def send_event (event_message : EventMessage ) -> None :
755
- await sse_stream_writer .send (
756
- {
757
- "event" : "message" ,
758
- "id" : event_message .event_id ,
759
- "data" : event_message .message .model_dump_json (
760
- by_alias = True , exclude_none = True
761
- ),
762
- }
763
- )
747
+ event_data = self ._create_event_data (event_message )
748
+ await sse_stream_writer .send (event_data )
764
749
765
750
# Replay past events and get the stream ID
766
751
stream_id = await event_store .replay_events_after (
@@ -777,16 +762,9 @@ async def send_event(event_message: EventMessage) -> None:
777
762
# Forward messages to SSE
778
763
async with msg_reader :
779
764
async for event_message in msg_reader :
780
- event_data = event_message .message .model_dump_json (
781
- by_alias = True , exclude_none = True
782
- )
783
- await sse_stream_writer .send (
784
- {
785
- "event" : "message" ,
786
- "id" : event_message .event_id ,
787
- "data" : event_data ,
788
- }
789
- )
765
+ event_data = self ._create_event_data (event_message )
766
+
767
+ await sse_stream_writer .send (event_data )
790
768
except Exception as e :
791
769
logger .exception (f"Error in replay sender: { e } " )
792
770
0 commit comments