Skip to content

fix: Pass cursor parameter to server #745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ async def list_resources(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourcesResult,
Expand All @@ -223,7 +225,9 @@ async def list_resource_templates(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListResourceTemplatesResult,
Expand Down Expand Up @@ -295,7 +299,9 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListPromptsResult,
Expand Down Expand Up @@ -340,7 +346,9 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
cursor=cursor,
params=types.PaginatedRequestParams(cursor=cursor)
if cursor is not None
else None,
)
),
types.ListToolsResult,
Expand Down
39 changes: 20 additions & 19 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ class Meta(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)


class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""


class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
Expand All @@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")


class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
Comment on lines -82 to -87
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this, since it was only used for the methods in question.

class PaginatedRequest(
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""

params: PaginatedRequestParams | None = None


class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
Expand Down Expand Up @@ -358,13 +367,10 @@ class ProgressNotification(
params: ProgressNotificationParams


class ListResourcesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""

method: Literal["resources/list"]
params: RequestParams | None = None


class Annotations(BaseModel):
Expand Down Expand Up @@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):


class ListResourceTemplatesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
PaginatedRequest[Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has."""

method: Literal["resources/templates/list"]
params: RequestParams | None = None


class ListResourceTemplatesResult(PaginatedResult):
Expand Down Expand Up @@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
params: ResourceUpdatedNotificationParams


class ListPromptsRequest(
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""

method: Literal["prompts/list"]
params: RequestParams | None = None


class PromptArgument(BaseModel):
Expand Down Expand Up @@ -703,11 +705,10 @@ class PromptListChangedNotification(
params: NotificationParams | None = None


class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""

method: Literal["tools/list"]
params: RequestParams | None = None


class ToolAnnotations(BaseModel):
Expand Down Expand Up @@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel):

idempotentHint: bool | None = None
"""
If true, calling the tool repeatedly with the same arguments
If true, calling the tool repeatedly with the same arguments
will have no additional effect on the its environment.
(This property is meaningful only when `readOnlyHint == false`)
Default: false
Expand Down
145 changes: 145 additions & 0 deletions tests/client/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from contextlib import asynccontextmanager
from unittest.mock import patch

import pytest

import mcp.shared.memory
from mcp.shared.message import SessionMessage
from mcp.types import (
JSONRPCNotification,
JSONRPCRequest,
)


class SpyMemoryObjectSendStream:
def __init__(self, original_stream):
self.original_stream = original_stream
self.sent_messages: list[SessionMessage] = []

async def send(self, message):
self.sent_messages.append(message)
await self.original_stream.send(message)

async def aclose(self):
await self.original_stream.aclose()

async def __aenter__(self):
return self

async def __aexit__(self, *args):
await self.aclose()


class StreamSpyCollection:
def __init__(
self,
client_spy: SpyMemoryObjectSendStream,
server_spy: SpyMemoryObjectSendStream,
):
self.client = client_spy
self.server = server_spy

def clear(self) -> None:
"""Clear all captured messages."""
self.client.sent_messages.clear()
self.server.sent_messages.clear()

def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get client-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.client.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]

def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
"""Get server-sent requests, optionally filtered by method."""
return [
req.message.root
for req in self.server.sent_messages
if isinstance(req.message.root, JSONRPCRequest)
and (method is None or req.message.root.method == method)
]

def get_client_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.client.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]

def get_server_notifications(
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method."""
return [
notif.message.root
for notif in self.server.sent_messages
if isinstance(notif.message.root, JSONRPCNotification)
and (method is None or notif.message.root.method == method)
]


@pytest.fixture
def stream_spy():
"""Fixture that provides spies for both client and server write streams.
Comment on lines +88 to +90
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Built this as a fixture so that it can be easily reused in multiple tests.

I needed a way to verify the request payload "on the wire" and this seemed like a good approach, since it did not require any modification of the actual client code (only test fixtures).

Example usage:
async def test_something(stream_spy):
# ... set up server and client ...
spies = stream_spy()
# Run some operation that sends messages
await client.some_operation()
# Check the messages
requests = spies.get_client_requests(method="some/method")
assert len(requests) == 1
# Clear for the next operation
spies.clear()
"""
client_spy = None
server_spy = None

# Store references to our spy objects
def capture_spies(c_spy, s_spy):
nonlocal client_spy, server_spy
client_spy = c_spy
server_spy = s_spy

# Create patched version of stream creation
original_create_streams = mcp.shared.memory.create_client_server_memory_streams

@asynccontextmanager
async def patched_create_streams():
async with original_create_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create spy wrappers
spy_client_write = SpyMemoryObjectSendStream(client_write)
spy_server_write = SpyMemoryObjectSendStream(server_write)

# Capture references for the test to use
capture_spies(spy_client_write, spy_server_write)

yield (client_read, spy_client_write), (server_read, spy_server_write)

# Apply the patch for the duration of the test
with patch(
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
):
# Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized"
assert server_spy is not None, "server_spy was not initialized"
return StreamSpyCollection(client_spy, server_spy)

yield get_spy_collection
Loading
Loading