Skip to content

Commit 0d651d5

Browse files
dsp-antClaude
and
Claude
committed
refactor: move MessageFrame class to types.py for better code organization
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ad1fbb0 commit 0d651d5

File tree

12 files changed

+133
-104
lines changed

12 files changed

+133
-104
lines changed

src/mcp/client/sse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
import mcp.types as types
1212
from mcp.shared.session import (
13-
MessageFrame,
1413
ReadStream,
1514
ReadStreamWriter,
1615
WriteStream,
1716
WriteStreamReader,
1817
)
18+
from mcp.types import MessageFrame
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -91,7 +91,7 @@ async def sse_reader(
9191
case "message":
9292
try:
9393
message = MessageFrame(
94-
types.JSONRPCMessage.model_validate_json( # noqa: E501
94+
root=types.JSONRPCMessage.model_validate_json( # noqa: E501
9595
sse.data
9696
),
9797
raw=sse,

src/mcp/server/sse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ async def handle_sse(request):
4646

4747
import mcp.types as types
4848
from mcp.shared.session import (
49-
MessageFrame,
5049
ReadStream,
5150
ReadStreamWriter,
5251
WriteStream,
5352
WriteStreamReader,
5453
)
54+
from mcp.types import MessageFrame
5555

5656
logger = logging.getLogger(__name__)
5757

@@ -176,4 +176,4 @@ async def handle_post_message(
176176
logger.debug(f"Sending message to writer: {message}")
177177
response = Response("Accepted", status_code=202)
178178
await response(scope, receive, send)
179-
await writer.send(MessageFrame(message, raw=request))
179+
await writer.send(MessageFrame(root=message, raw=request))

src/mcp/server/stdio.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ async def run_server():
2727

2828
import mcp.types as types
2929
from mcp.shared.session import (
30-
MessageFrame,
3130
ReadStream,
3231
ReadStreamWriter,
3332
WriteStream,
3433
WriteStreamReader,
3534
)
35+
from mcp.types import MessageFrame
3636

3737

3838
@asynccontextmanager
@@ -72,14 +72,15 @@ async def stdin_reader():
7272
await read_stream_writer.send(exc)
7373
continue
7474

75-
await read_stream_writer.send(MessageFrame(message, raw=line))
75+
await read_stream_writer.send(MessageFrame(root=message, raw=line))
7676
except anyio.ClosedResourceError:
7777
await anyio.lowlevel.checkpoint()
7878

7979
async def stdout_writer():
8080
try:
8181
async with write_stream_reader:
8282
async for message in write_stream_reader:
83+
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
8384
json = message.model_dump_json(by_alias=True, exclude_none=True)
8485
await stdout.write(json + "\n")
8586
await stdout.flush()

src/mcp/server/websocket.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
import mcp.types as types
99
from mcp.shared.session import (
10-
MessageFrame,
1110
ReadStream,
1211
ReadStreamWriter,
1312
WriteStream,
1413
WriteStreamReader,
1514
)
15+
from mcp.types import MessageFrame
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -47,7 +47,7 @@ async def ws_reader():
4747
continue
4848

4949
await read_stream_writer.send(
50-
MessageFrame(client_message, raw=message)
50+
MessageFrame(root=client_message, raw=message)
5151
)
5252
except anyio.ClosedResourceError:
5353
await websocket.close()

src/mcp/shared/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
14-
from mcp.shared.session import MessageFrame
14+
from mcp.types import MessageFrame
1515

1616
MessageStream = tuple[
1717
MemoryObjectReceiveStream[MessageFrame | Exception],

src/mcp/shared/session.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import anyio.lowlevel
88
import httpx
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10-
from pydantic import BaseModel, RootModel
10+
from pydantic import BaseModel
1111

1212
from mcp.shared.exceptions import McpError
1313
from mcp.types import (
@@ -21,27 +21,17 @@
2121
JSONRPCNotification,
2222
JSONRPCRequest,
2323
JSONRPCResponse,
24+
MessageFrame,
2425
RequestParams,
2526
ServerNotification,
2627
ServerRequest,
2728
ServerResult,
2829
)
2930

30-
RawT = TypeVar("RawT")
31-
32-
33-
class MessageFrame(RootModel[JSONRPCMessage], Generic[RawT]):
34-
root: JSONRPCMessage
35-
raw: RawT | None = None
36-
37-
class Config:
38-
arbitrary_types_allowed = True
39-
40-
41-
ReadStream = MemoryObjectReceiveStream[MessageFrame[RawT] | Exception]
42-
ReadStreamWriter = MemoryObjectSendStream[MessageFrame[RawT] | Exception]
43-
WriteStream = MemoryObjectSendStream[MessageFrame[RawT]]
44-
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame[RawT]]
31+
ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception]
32+
ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception]
33+
WriteStream = MemoryObjectSendStream[MessageFrame]
34+
WriteStreamReader = MemoryObjectReceiveStream[MessageFrame]
4535

4636
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
4737
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
@@ -242,7 +232,7 @@ async def send_request(
242232
# TODO: Support progress callbacks
243233

244234
await self._write_stream.send(
245-
MessageFrame(JSONRPCMessage(jsonrpc_request), None)
235+
MessageFrame(root=JSONRPCMessage(jsonrpc_request), raw=None)
246236
)
247237

248238
try:
@@ -280,15 +270,17 @@ async def send_notification(self, notification: SendNotificationT) -> None:
280270
)
281271

282272
await self._write_stream.send(
283-
MessageFrame(JSONRPCMessage(jsonrpc_notification))
273+
MessageFrame(root=JSONRPCMessage(jsonrpc_notification), raw=None)
284274
)
285275

286276
async def _send_response(
287277
self, request_id: RequestId, response: SendResultT | ErrorData
288278
) -> None:
289279
if isinstance(response, ErrorData):
290280
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
291-
await self._write_stream.send(MessageFrame(JSONRPCMessage(jsonrpc_error)))
281+
await self._write_stream.send(
282+
MessageFrame(root=JSONRPCMessage(jsonrpc_error), raw=None)
283+
)
292284
else:
293285
jsonrpc_response = JSONRPCResponse(
294286
jsonrpc="2.0",
@@ -298,7 +290,7 @@ async def _send_response(
298290
),
299291
)
300292
await self._write_stream.send(
301-
MessageFrame(JSONRPCMessage(jsonrpc_response))
293+
MessageFrame(root=JSONRPCMessage(jsonrpc_response), raw=None)
302294
)
303295

304296
async def _receive_loop(self) -> None:

src/mcp/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,21 @@ class JSONRPCMessage(
180180
pass
181181

182182

183+
RawT = TypeVar("RawT")
184+
185+
186+
class MessageFrame(BaseModel, Generic[RawT]):
187+
root: JSONRPCMessage
188+
raw: RawT | None = None
189+
model_config = ConfigDict(extra="allow")
190+
191+
def model_dump(self, *args, **kwargs):
192+
return self.root.model_dump(*args, **kwargs)
193+
194+
def model_dump_json(self, *args, **kwargs):
195+
return self.root.model_dump_json(*args, **kwargs)
196+
197+
183198
class EmptyResult(Result):
184199
"""A response that indicates success but carries no data."""
185200

tests/client/test_session.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pytest
33

44
from mcp.client.session import ClientSession
5-
from mcp.shared.session import MessageFrame
65
from mcp.types import (
76
LATEST_PROTOCOL_VERSION,
87
ClientNotification,
@@ -12,8 +11,8 @@
1211
InitializeRequest,
1312
InitializeResult,
1413
JSONRPCMessage,
15-
JSONRPCRequest,
1614
JSONRPCResponse,
15+
MessageFrame,
1716
ServerCapabilities,
1817
ServerResult,
1918
)
@@ -34,7 +33,7 @@ async def mock_server():
3433
nonlocal initialized_notification
3534

3635
jsonrpc_request = await client_to_server_receive.receive()
37-
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
36+
assert isinstance(jsonrpc_request, MessageFrame)
3837
request = ClientRequest.model_validate(
3938
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
4039
)
@@ -61,7 +60,7 @@ async def mock_server():
6160
root=JSONRPCMessage(
6261
JSONRPCResponse(
6362
jsonrpc="2.0",
64-
id=jsonrpc_request.root.id,
63+
id=jsonrpc_request.root.root.id,
6564
result=result.model_dump(
6665
by_alias=True, mode="json", exclude_none=True
6766
),
@@ -71,7 +70,7 @@ async def mock_server():
7170
)
7271
)
7372
jsonrpc_notification = await client_to_server_receive.receive()
74-
assert isinstance(jsonrpc_notification.root, MessageFrame)
73+
assert isinstance(jsonrpc_notification.root, JSONRPCMessage)
7574
initialized_notification = ClientNotification.model_validate(
7675
jsonrpc_notification.root.model_dump(
7776
by_alias=True, mode="json", exclude_none=True

tests/issues/test_192_request_id.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
JSONRPCMessage,
1212
JSONRPCNotification,
1313
JSONRPCRequest,
14+
MessageFrame,
1415
NotificationParams,
1516
)
1617

@@ -58,7 +59,9 @@ async def run_server():
5859
jsonrpc="2.0",
5960
)
6061

61-
await client_writer.send(JSONRPCMessage(root=init_req))
62+
await client_writer.send(
63+
MessageFrame(root=JSONRPCMessage(root=init_req), raw=None)
64+
)
6265
await server_reader.receive() # Get init response but don't need to check it
6366

6467
# Send initialized notification
@@ -67,21 +70,25 @@ async def run_server():
6770
params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
6871
jsonrpc="2.0",
6972
)
70-
await client_writer.send(JSONRPCMessage(root=initialized_notification))
73+
await client_writer.send(
74+
MessageFrame(root=JSONRPCMessage(root=initialized_notification), raw=None)
75+
)
7176

7277
# Send ping request with custom ID
7378
ping_request = JSONRPCRequest(
7479
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
7580
)
7681

77-
await client_writer.send(JSONRPCMessage(root=ping_request))
82+
await client_writer.send(
83+
MessageFrame(root=JSONRPCMessage(root=ping_request), raw=None)
84+
)
7885

7986
# Read response
8087
response = await server_reader.receive()
8188

8289
# Verify response ID matches request ID
8390
assert (
84-
response.root.id == custom_request_id
91+
response.root.root.id == custom_request_id
8592
), "Response ID should match request ID"
8693

8794
# Cancel server task

0 commit comments

Comments
 (0)