Skip to content

Commit e80c015

Browse files
authored
fix: Pass cursor parameter to server (#745)
1 parent 2ca2de7 commit e80c015

File tree

5 files changed

+306
-73
lines changed

5 files changed

+306
-73
lines changed

src/mcp/client/session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ async def list_resources(
209209
types.ClientRequest(
210210
types.ListResourcesRequest(
211211
method="resources/list",
212-
cursor=cursor,
212+
params=types.PaginatedRequestParams(cursor=cursor)
213+
if cursor is not None
214+
else None,
213215
)
214216
),
215217
types.ListResourcesResult,
@@ -223,7 +225,9 @@ async def list_resource_templates(
223225
types.ClientRequest(
224226
types.ListResourceTemplatesRequest(
225227
method="resources/templates/list",
226-
cursor=cursor,
228+
params=types.PaginatedRequestParams(cursor=cursor)
229+
if cursor is not None
230+
else None,
227231
)
228232
),
229233
types.ListResourceTemplatesResult,
@@ -295,7 +299,9 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
295299
types.ClientRequest(
296300
types.ListPromptsRequest(
297301
method="prompts/list",
298-
cursor=cursor,
302+
params=types.PaginatedRequestParams(cursor=cursor)
303+
if cursor is not None
304+
else None,
299305
)
300306
),
301307
types.ListPromptsResult,
@@ -340,7 +346,9 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:
340346
types.ClientRequest(
341347
types.ListToolsRequest(
342348
method="tools/list",
343-
cursor=cursor,
349+
params=types.PaginatedRequestParams(cursor=cursor)
350+
if cursor is not None
351+
else None,
344352
)
345353
),
346354
types.ListToolsResult,

src/mcp/types.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class Meta(BaseModel):
5353
meta: Meta | None = Field(alias="_meta", default=None)
5454

5555

56+
class PaginatedRequestParams(RequestParams):
57+
cursor: Cursor | None = None
58+
"""
59+
An opaque token representing the current pagination position.
60+
If provided, the server should return results starting after this cursor.
61+
"""
62+
63+
5664
class NotificationParams(BaseModel):
5765
class Meta(BaseModel):
5866
model_config = ConfigDict(extra="allow")
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
7987
model_config = ConfigDict(extra="allow")
8088

8189

82-
class PaginatedRequest(Request[RequestParamsT, MethodT]):
83-
cursor: Cursor | None = None
84-
"""
85-
An opaque token representing the current pagination position.
86-
If provided, the server should return results starting after this cursor.
87-
"""
90+
class PaginatedRequest(
91+
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
92+
):
93+
"""Base class for paginated requests,
94+
matching the schema's PaginatedRequest interface."""
95+
96+
params: PaginatedRequestParams | None = None
8897

8998

9099
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -358,13 +367,10 @@ class ProgressNotification(
358367
params: ProgressNotificationParams
359368

360369

361-
class ListResourcesRequest(
362-
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
363-
):
370+
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
364371
"""Sent from the client to request a list of resources the server has."""
365372

366373
method: Literal["resources/list"]
367-
params: RequestParams | None = None
368374

369375

370376
class Annotations(BaseModel):
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):
423429

424430

425431
class ListResourceTemplatesRequest(
426-
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
432+
PaginatedRequest[Literal["resources/templates/list"]]
427433
):
428434
"""Sent from the client to request a list of resource templates the server has."""
429435

430436
method: Literal["resources/templates/list"]
431-
params: RequestParams | None = None
432437

433438

434439
class ListResourceTemplatesResult(PaginatedResult):
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
570575
params: ResourceUpdatedNotificationParams
571576

572577

573-
class ListPromptsRequest(
574-
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
575-
):
578+
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
576579
"""Sent from the client to request a list of prompts and prompt templates."""
577580

578581
method: Literal["prompts/list"]
579-
params: RequestParams | None = None
580582

581583

582584
class PromptArgument(BaseModel):
@@ -703,11 +705,10 @@ class PromptListChangedNotification(
703705
params: NotificationParams | None = None
704706

705707

706-
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
708+
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
707709
"""Sent from the client to request a list of tools the server has."""
708710

709711
method: Literal["tools/list"]
710-
params: RequestParams | None = None
711712

712713

713714
class ToolAnnotations(BaseModel):
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel):
741742

742743
idempotentHint: bool | None = None
743744
"""
744-
If true, calling the tool repeatedly with the same arguments
745+
If true, calling the tool repeatedly with the same arguments
745746
will have no additional effect on the its environment.
746747
(This property is meaningful only when `readOnlyHint == false`)
747748
Default: false

tests/client/conftest.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from contextlib import asynccontextmanager
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
import mcp.shared.memory
7+
from mcp.shared.message import SessionMessage
8+
from mcp.types import (
9+
JSONRPCNotification,
10+
JSONRPCRequest,
11+
)
12+
13+
14+
class SpyMemoryObjectSendStream:
15+
def __init__(self, original_stream):
16+
self.original_stream = original_stream
17+
self.sent_messages: list[SessionMessage] = []
18+
19+
async def send(self, message):
20+
self.sent_messages.append(message)
21+
await self.original_stream.send(message)
22+
23+
async def aclose(self):
24+
await self.original_stream.aclose()
25+
26+
async def __aenter__(self):
27+
return self
28+
29+
async def __aexit__(self, *args):
30+
await self.aclose()
31+
32+
33+
class StreamSpyCollection:
34+
def __init__(
35+
self,
36+
client_spy: SpyMemoryObjectSendStream,
37+
server_spy: SpyMemoryObjectSendStream,
38+
):
39+
self.client = client_spy
40+
self.server = server_spy
41+
42+
def clear(self) -> None:
43+
"""Clear all captured messages."""
44+
self.client.sent_messages.clear()
45+
self.server.sent_messages.clear()
46+
47+
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
48+
"""Get client-sent requests, optionally filtered by method."""
49+
return [
50+
req.message.root
51+
for req in self.client.sent_messages
52+
if isinstance(req.message.root, JSONRPCRequest)
53+
and (method is None or req.message.root.method == method)
54+
]
55+
56+
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
57+
"""Get server-sent requests, optionally filtered by method."""
58+
return [
59+
req.message.root
60+
for req in self.server.sent_messages
61+
if isinstance(req.message.root, JSONRPCRequest)
62+
and (method is None or req.message.root.method == method)
63+
]
64+
65+
def get_client_notifications(
66+
self, method: str | None = None
67+
) -> list[JSONRPCNotification]:
68+
"""Get client-sent notifications, optionally filtered by method."""
69+
return [
70+
notif.message.root
71+
for notif in self.client.sent_messages
72+
if isinstance(notif.message.root, JSONRPCNotification)
73+
and (method is None or notif.message.root.method == method)
74+
]
75+
76+
def get_server_notifications(
77+
self, method: str | None = None
78+
) -> list[JSONRPCNotification]:
79+
"""Get server-sent notifications, optionally filtered by method."""
80+
return [
81+
notif.message.root
82+
for notif in self.server.sent_messages
83+
if isinstance(notif.message.root, JSONRPCNotification)
84+
and (method is None or notif.message.root.method == method)
85+
]
86+
87+
88+
@pytest.fixture
89+
def stream_spy():
90+
"""Fixture that provides spies for both client and server write streams.
91+
92+
Example usage:
93+
async def test_something(stream_spy):
94+
# ... set up server and client ...
95+
96+
spies = stream_spy()
97+
98+
# Run some operation that sends messages
99+
await client.some_operation()
100+
101+
# Check the messages
102+
requests = spies.get_client_requests(method="some/method")
103+
assert len(requests) == 1
104+
105+
# Clear for the next operation
106+
spies.clear()
107+
"""
108+
client_spy = None
109+
server_spy = None
110+
111+
# Store references to our spy objects
112+
def capture_spies(c_spy, s_spy):
113+
nonlocal client_spy, server_spy
114+
client_spy = c_spy
115+
server_spy = s_spy
116+
117+
# Create patched version of stream creation
118+
original_create_streams = mcp.shared.memory.create_client_server_memory_streams
119+
120+
@asynccontextmanager
121+
async def patched_create_streams():
122+
async with original_create_streams() as (client_streams, server_streams):
123+
client_read, client_write = client_streams
124+
server_read, server_write = server_streams
125+
126+
# Create spy wrappers
127+
spy_client_write = SpyMemoryObjectSendStream(client_write)
128+
spy_server_write = SpyMemoryObjectSendStream(server_write)
129+
130+
# Capture references for the test to use
131+
capture_spies(spy_client_write, spy_server_write)
132+
133+
yield (client_read, spy_client_write), (server_read, spy_server_write)
134+
135+
# Apply the patch for the duration of the test
136+
with patch(
137+
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
138+
):
139+
# Return a collection with helper methods
140+
def get_spy_collection() -> StreamSpyCollection:
141+
assert client_spy is not None, "client_spy was not initialized"
142+
assert server_spy is not None, "server_spy was not initialized"
143+
return StreamSpyCollection(client_spy, server_spy)
144+
145+
yield get_spy_collection

0 commit comments

Comments
 (0)