-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix: Pass cursor parameter to server #745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7e7674e
82f35ef
3adff13
a529f39
ba90839
bea1093
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,14 @@ class Meta(BaseModel): | |
meta: Meta | None = Field(alias="_meta", default=None) | ||
|
||
|
||
class PaginatedRequestParams(RequestParams): | ||
cursor: Cursor | None = None | ||
""" | ||
An opaque token representing the current pagination position. | ||
If provided, the server should return results starting after this cursor. | ||
""" | ||
|
||
|
||
class NotificationParams(BaseModel): | ||
class Meta(BaseModel): | ||
model_config = ConfigDict(extra="allow") | ||
|
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): | |
model_config = ConfigDict(extra="allow") | ||
|
||
|
||
class PaginatedRequest(Request[RequestParamsT, MethodT]): | ||
cursor: Cursor | None = None | ||
""" | ||
An opaque token representing the current pagination position. | ||
If provided, the server should return results starting after this cursor. | ||
""" | ||
Comment on lines
-82
to
-87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed this, since it was only used for the methods in question. |
||
class PaginatedRequest( | ||
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT] | ||
): | ||
"""Base class for paginated requests, | ||
matching the schema's PaginatedRequest interface.""" | ||
|
||
params: PaginatedRequestParams | None = None | ||
|
||
|
||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): | ||
|
@@ -358,13 +367,10 @@ class ProgressNotification( | |
params: ProgressNotificationParams | ||
|
||
|
||
class ListResourcesRequest( | ||
PaginatedRequest[RequestParams | None, Literal["resources/list"]] | ||
): | ||
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): | ||
"""Sent from the client to request a list of resources the server has.""" | ||
|
||
method: Literal["resources/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class Annotations(BaseModel): | ||
|
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult): | |
|
||
|
||
class ListResourceTemplatesRequest( | ||
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]] | ||
PaginatedRequest[Literal["resources/templates/list"]] | ||
): | ||
"""Sent from the client to request a list of resource templates the server has.""" | ||
|
||
method: Literal["resources/templates/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class ListResourceTemplatesResult(PaginatedResult): | ||
|
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification( | |
params: ResourceUpdatedNotificationParams | ||
|
||
|
||
class ListPromptsRequest( | ||
PaginatedRequest[RequestParams | None, Literal["prompts/list"]] | ||
): | ||
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): | ||
"""Sent from the client to request a list of prompts and prompt templates.""" | ||
|
||
method: Literal["prompts/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class PromptArgument(BaseModel): | ||
|
@@ -703,11 +705,10 @@ class PromptListChangedNotification( | |
params: NotificationParams | None = None | ||
|
||
|
||
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): | ||
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): | ||
"""Sent from the client to request a list of tools the server has.""" | ||
|
||
method: Literal["tools/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class ToolAnnotations(BaseModel): | ||
|
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel): | |
|
||
idempotentHint: bool | None = None | ||
""" | ||
If true, calling the tool repeatedly with the same arguments | ||
If true, calling the tool repeatedly with the same arguments | ||
will have no additional effect on the its environment. | ||
(This property is meaningful only when `readOnlyHint == false`) | ||
Default: false | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from contextlib import asynccontextmanager | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
import mcp.shared.memory | ||
from mcp.shared.message import SessionMessage | ||
from mcp.types import ( | ||
JSONRPCNotification, | ||
JSONRPCRequest, | ||
) | ||
|
||
|
||
class SpyMemoryObjectSendStream: | ||
def __init__(self, original_stream): | ||
self.original_stream = original_stream | ||
self.sent_messages: list[SessionMessage] = [] | ||
|
||
async def send(self, message): | ||
self.sent_messages.append(message) | ||
await self.original_stream.send(message) | ||
|
||
async def aclose(self): | ||
await self.original_stream.aclose() | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, *args): | ||
await self.aclose() | ||
|
||
|
||
class StreamSpyCollection: | ||
def __init__( | ||
self, | ||
client_spy: SpyMemoryObjectSendStream, | ||
server_spy: SpyMemoryObjectSendStream, | ||
): | ||
self.client = client_spy | ||
self.server = server_spy | ||
|
||
def clear(self) -> None: | ||
"""Clear all captured messages.""" | ||
self.client.sent_messages.clear() | ||
self.server.sent_messages.clear() | ||
|
||
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: | ||
"""Get client-sent requests, optionally filtered by method.""" | ||
return [ | ||
req.message.root | ||
for req in self.client.sent_messages | ||
if isinstance(req.message.root, JSONRPCRequest) | ||
and (method is None or req.message.root.method == method) | ||
] | ||
|
||
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: | ||
"""Get server-sent requests, optionally filtered by method.""" | ||
return [ | ||
req.message.root | ||
for req in self.server.sent_messages | ||
if isinstance(req.message.root, JSONRPCRequest) | ||
and (method is None or req.message.root.method == method) | ||
] | ||
|
||
def get_client_notifications( | ||
self, method: str | None = None | ||
) -> list[JSONRPCNotification]: | ||
"""Get client-sent notifications, optionally filtered by method.""" | ||
return [ | ||
notif.message.root | ||
for notif in self.client.sent_messages | ||
if isinstance(notif.message.root, JSONRPCNotification) | ||
and (method is None or notif.message.root.method == method) | ||
] | ||
|
||
def get_server_notifications( | ||
self, method: str | None = None | ||
) -> list[JSONRPCNotification]: | ||
"""Get server-sent notifications, optionally filtered by method.""" | ||
return [ | ||
notif.message.root | ||
for notif in self.server.sent_messages | ||
if isinstance(notif.message.root, JSONRPCNotification) | ||
and (method is None or notif.message.root.method == method) | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def stream_spy(): | ||
"""Fixture that provides spies for both client and server write streams. | ||
Comment on lines
+88
to
+90
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Built this as a fixture so that it can be easily reused in multiple tests. I needed a way to verify the request payload "on the wire" and this seemed like a good approach, since it did not require any modification of the actual client code (only test fixtures). |
||
Example usage: | ||
async def test_something(stream_spy): | ||
# ... set up server and client ... | ||
spies = stream_spy() | ||
# Run some operation that sends messages | ||
await client.some_operation() | ||
# Check the messages | ||
requests = spies.get_client_requests(method="some/method") | ||
assert len(requests) == 1 | ||
# Clear for the next operation | ||
spies.clear() | ||
""" | ||
client_spy = None | ||
server_spy = None | ||
|
||
# Store references to our spy objects | ||
def capture_spies(c_spy, s_spy): | ||
nonlocal client_spy, server_spy | ||
client_spy = c_spy | ||
server_spy = s_spy | ||
|
||
# Create patched version of stream creation | ||
original_create_streams = mcp.shared.memory.create_client_server_memory_streams | ||
|
||
@asynccontextmanager | ||
async def patched_create_streams(): | ||
async with original_create_streams() as (client_streams, server_streams): | ||
client_read, client_write = client_streams | ||
server_read, server_write = server_streams | ||
|
||
# Create spy wrappers | ||
spy_client_write = SpyMemoryObjectSendStream(client_write) | ||
spy_server_write = SpyMemoryObjectSendStream(server_write) | ||
|
||
# Capture references for the test to use | ||
capture_spies(spy_client_write, spy_server_write) | ||
|
||
yield (client_read, spy_client_write), (server_read, spy_server_write) | ||
|
||
# Apply the patch for the duration of the test | ||
with patch( | ||
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams | ||
): | ||
# Return a collection with helper methods | ||
def get_spy_collection() -> StreamSpyCollection: | ||
assert client_spy is not None, "client_spy was not initialized" | ||
assert server_spy is not None, "server_spy was not initialized" | ||
return StreamSpyCollection(client_spy, server_spy) | ||
|
||
yield get_spy_collection |
Uh oh!
There was an error while loading. Please reload this page.