Skip to content

Commit 7e7674e

Browse files
committed
fix: Pass cursor parameter to server
1 parent 6353dd1 commit 7e7674e

File tree

4 files changed

+320
-69
lines changed

4 files changed

+320
-69
lines changed

src/mcp/client/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ async def list_resources(
209209
types.ClientRequest(
210210
types.ListResourcesRequest(
211211
method="resources/list",
212-
cursor=cursor,
212+
params=types.PaginatedRequestParams(cursor=cursor),
213213
)
214214
),
215215
types.ListResourcesResult,
@@ -223,7 +223,7 @@ async def list_resource_templates(
223223
types.ClientRequest(
224224
types.ListResourceTemplatesRequest(
225225
method="resources/templates/list",
226-
cursor=cursor,
226+
params=types.PaginatedRequestParams(cursor=cursor),
227227
)
228228
),
229229
types.ListResourceTemplatesResult,
@@ -295,7 +295,7 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
295295
types.ClientRequest(
296296
types.ListPromptsRequest(
297297
method="prompts/list",
298-
cursor=cursor,
298+
params=types.PaginatedRequestParams(cursor=cursor),
299299
)
300300
),
301301
types.ListPromptsResult,
@@ -340,7 +340,7 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:
340340
types.ClientRequest(
341341
types.ListToolsRequest(
342342
method="tools/list",
343-
cursor=cursor,
343+
params=types.PaginatedRequestParams(cursor=cursor),
344344
)
345345
),
346346
types.ListToolsResult,

src/mcp/types.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class Meta(BaseModel):
5353
meta: Meta | None = Field(alias="_meta", default=None)
5454

5555

56+
class PaginatedRequestParams(RequestParams):
57+
cursor: Cursor | None = None
58+
"""
59+
An opaque token representing the current pagination position.
60+
If provided, the server should return results starting after this cursor.
61+
"""
62+
63+
5664
class NotificationParams(BaseModel):
5765
class Meta(BaseModel):
5866
model_config = ConfigDict(extra="allow")
@@ -79,14 +87,6 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
7987
model_config = ConfigDict(extra="allow")
8088

8189

82-
class PaginatedRequest(Request[RequestParamsT, MethodT]):
83-
cursor: Cursor | None = None
84-
"""
85-
An opaque token representing the current pagination position.
86-
If provided, the server should return results starting after this cursor.
87-
"""
88-
89-
9090
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
9191
"""Base class for JSON-RPC notifications."""
9292

@@ -359,12 +359,12 @@ class ProgressNotification(
359359

360360

361361
class ListResourcesRequest(
362-
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
362+
Request[PaginatedRequestParams | None, Literal["resources/list"]]
363363
):
364364
"""Sent from the client to request a list of resources the server has."""
365365

366366
method: Literal["resources/list"]
367-
params: RequestParams | None = None
367+
params: PaginatedRequestParams | None = None
368368

369369

370370
class Annotations(BaseModel):
@@ -423,12 +423,12 @@ class ListResourcesResult(PaginatedResult):
423423

424424

425425
class ListResourceTemplatesRequest(
426-
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
426+
Request[PaginatedRequestParams | None, Literal["resources/templates/list"]]
427427
):
428428
"""Sent from the client to request a list of resource templates the server has."""
429429

430430
method: Literal["resources/templates/list"]
431-
params: RequestParams | None = None
431+
params: PaginatedRequestParams | None = None
432432

433433

434434
class ListResourceTemplatesResult(PaginatedResult):
@@ -571,12 +571,12 @@ class ResourceUpdatedNotification(
571571

572572

573573
class ListPromptsRequest(
574-
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
574+
Request[PaginatedRequestParams | None, Literal["prompts/list"]]
575575
):
576576
"""Sent from the client to request a list of prompts and prompt templates."""
577577

578578
method: Literal["prompts/list"]
579-
params: RequestParams | None = None
579+
params: PaginatedRequestParams | None = None
580580

581581

582582
class PromptArgument(BaseModel):
@@ -703,11 +703,11 @@ class PromptListChangedNotification(
703703
params: NotificationParams | None = None
704704

705705

706-
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
706+
class ListToolsRequest(Request[PaginatedRequestParams | None, Literal["tools/list"]]):
707707
"""Sent from the client to request a list of tools the server has."""
708708

709709
method: Literal["tools/list"]
710-
params: RequestParams | None = None
710+
params: PaginatedRequestParams | None = None
711711

712712

713713
class ToolAnnotations(BaseModel):

tests/client/conftest.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from contextlib import asynccontextmanager
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
import mcp.shared.memory
7+
from mcp.shared.message import SessionMessage
8+
from mcp.types import (
9+
JSONRPCNotification,
10+
JSONRPCRequest,
11+
)
12+
13+
14+
class SpyMemoryObjectSendStream:
15+
def __init__(self, original_stream):
16+
self.original_stream = original_stream
17+
self.sent_messages: list[SessionMessage] = []
18+
19+
async def send(self, message):
20+
self.sent_messages.append(message)
21+
await self.original_stream.send(message)
22+
23+
async def aclose(self):
24+
await self.original_stream.aclose()
25+
26+
async def __aenter__(self):
27+
return self
28+
29+
async def __aexit__(self, *args):
30+
await self.aclose()
31+
32+
33+
class StreamSpyCollection:
34+
def __init__(
35+
self,
36+
client_spy: SpyMemoryObjectSendStream,
37+
server_spy: SpyMemoryObjectSendStream,
38+
):
39+
self.client = client_spy
40+
self.server = server_spy
41+
42+
def clear(self) -> None:
43+
"""Clear all captured messages."""
44+
self.client.sent_messages.clear()
45+
self.server.sent_messages.clear()
46+
47+
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
48+
"""Get client-sent requests, optionally filtered by method."""
49+
return [
50+
req.message.root
51+
for req in self.client.sent_messages
52+
if isinstance(req.message.root, JSONRPCRequest)
53+
and (method is None or req.message.root.method == method)
54+
]
55+
56+
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
57+
"""Get server-sent requests, optionally filtered by method."""
58+
return [
59+
req.message.root
60+
for req in self.server.sent_messages
61+
if isinstance(req.message.root, JSONRPCRequest)
62+
and (method is None or req.message.root.method == method)
63+
]
64+
65+
def get_client_notifications(
66+
self, method: str | None = None
67+
) -> list[JSONRPCNotification]:
68+
"""Get client-sent notifications, optionally filtered by method."""
69+
return [
70+
notif.message.root
71+
for notif in self.client.sent_messages
72+
if isinstance(notif.message.root, JSONRPCNotification)
73+
and (method is None or notif.message.root.method == method)
74+
]
75+
76+
def get_server_notifications(
77+
self, method: str | None = None
78+
) -> list[JSONRPCNotification]:
79+
"""Get server-sent notifications, optionally filtered by method."""
80+
return [
81+
notif.message.root
82+
for notif in self.server.sent_messages
83+
if isinstance(notif.message.root, JSONRPCNotification)
84+
and (method is None or notif.message.root.method == method)
85+
]
86+
87+
88+
@pytest.fixture
89+
def stream_spy():
90+
"""Fixture that provides spies for both client and server write streams.
91+
92+
Example usage:
93+
async def test_something(stream_spy):
94+
# ... set up server and client ...
95+
96+
spies = stream_spy()
97+
98+
# Run some operation that sends messages
99+
await client.some_operation()
100+
101+
# Check the messages
102+
requests = spies.get_client_requests(method="some/method")
103+
assert len(requests) == 1
104+
105+
# Clear for the next operation
106+
spies.clear()
107+
"""
108+
client_spy = None
109+
server_spy = None
110+
111+
# Store references to our spy objects
112+
def capture_spies(c_spy, s_spy):
113+
nonlocal client_spy, server_spy
114+
client_spy = c_spy
115+
server_spy = s_spy
116+
117+
# Create patched version of stream creation
118+
original_create_streams = mcp.shared.memory.create_client_server_memory_streams
119+
120+
@asynccontextmanager
121+
async def patched_create_streams():
122+
async with original_create_streams() as (client_streams, server_streams):
123+
client_read, client_write = client_streams
124+
server_read, server_write = server_streams
125+
126+
# Create spy wrappers
127+
spy_client_write = SpyMemoryObjectSendStream(client_write)
128+
spy_server_write = SpyMemoryObjectSendStream(server_write)
129+
130+
# Capture references for the test to use
131+
capture_spies(spy_client_write, spy_server_write)
132+
133+
yield (client_read, spy_client_write), (server_read, spy_server_write)
134+
135+
# Apply the patch for the duration of the test
136+
with patch(
137+
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
138+
):
139+
# Return a collection with helper methods
140+
yield lambda: StreamSpyCollection(client_spy, server_spy)

0 commit comments

Comments
 (0)