From 163434393169d10850ff7b8b80cef90a972a14df Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 6 Nov 2024 12:24:46 +0000 Subject: [PATCH 1/5] Update types for spec changes --- mcp_python/types.py | 163 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 151 insertions(+), 12 deletions(-) diff --git a/mcp_python/types.py b/mcp_python/types.py index b3ab4dd98..0e26e7460 100644 --- a/mcp_python/types.py +++ b/mcp_python/types.py @@ -1,12 +1,12 @@ from typing import Any, Generic, Literal, TypeVar -from pydantic import BaseModel, ConfigDict, RootModel +from pydantic import BaseModel, ConfigDict, FileUrl, RootModel from pydantic.networks import AnyUrl """ Model Context Protocol bindings for Python -These bindings were generated from https://github.com/anthropic-experimental/mcp-spec, +These bindings were generated from https://github.com/modelcontextprotocol/specification, using Claude, with a prompt something like the following: Generate idiomatic Python bindings for this schema for MCP, or the "Model Context @@ -21,7 +21,7 @@ not separate types in the schema. """ -LATEST_PROTOCOL_VERSION = "2024-10-07" +LATEST_PROTOCOL_VERSION = "2024-11-05" ProgressToken = str | int Cursor = str @@ -191,6 +191,8 @@ class ClientCapabilities(BaseModel): """Experimental, non-standard capabilities that the client supports.""" sampling: dict[str, Any] | None = None """Present if the client supports sampling from an LLM.""" + roots: dict[str, Any] | None = None + """Present if the client supports listing roots.""" model_config = ConfigDict(extra="allow") @@ -556,12 +558,33 @@ class SamplingMessage(BaseModel): model_config = ConfigDict(extra="allow") +class EmbeddedResource(BaseModel): + """ + The contents of a resource, embedded into a prompt or tool call result. + + It is up to the client how best to render embedded resources for the benefit + of the LLM and/or the user. + """ + + type: Literal["resource"] + resource: TextResourceContents | BlobResourceContents + model_config = ConfigDict(extra="allow") + + +class PromptMessage(BaseModel): + """Describes a message returned as part of a prompt.""" + + role: Role + content: TextContent | ImageContent | EmbeddedResource + model_config = ConfigDict(extra="allow") + + class GetPromptResult(Result): """The server's response to a prompts/get request from the client.""" description: str | None = None """An optional description for the prompt.""" - messages: list[SamplingMessage] + messages: list[PromptMessage] class PromptListChangedNotification(Notification): @@ -617,7 +640,8 @@ class CallToolRequest(Request): class CallToolResult(Result): """The server's response to a tool call.""" - toolResult: Any + content: list[TextContent | ImageContent | EmbeddedResource] + isError: bool class ToolListChangedNotification(Notification): @@ -630,7 +654,7 @@ class ToolListChangedNotification(Notification): params: NotificationParams | None = None -LoggingLevel = Literal["debug", "info", "warning", "error"] +LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] class SetLevelRequestParams(RequestParams): @@ -673,10 +697,71 @@ class LoggingMessageNotification(Notification): IncludeContext = Literal["none", "thisServer", "allServers"] +class ModelHint(BaseModel): + """Hints to use for model selection.""" + + name: str | None = None + """A hint for a model name.""" + + model_config = ConfigDict(extra="allow") + + +class ModelPreferences(BaseModel): + """ + The server's preferences for model selection, requested of the client during sampling. + + Because LLMs can vary along multiple dimensions, choosing the "best" model is + rarely straightforward. Different models excel in different areas—some are + faster but less capable, others are more capable but more expensive, and so + on. This interface allows servers to express their priorities across multiple + dimensions to help clients make an appropriate selection for their use case. + + These preferences are always advisory. The client MAY ignore them. It is also + up to the client to decide how to interpret these preferences and how to + balance them against other considerations. + """ + + hints: list[ModelHint] | None = None + """ + Optional hints to use for model selection. + + If multiple hints are specified, the client MUST evaluate them in order + (such that the first match is taken). + + The client SHOULD prioritize these hints over the numeric priorities, but + MAY still use the priorities to select from ambiguous matches. + """ + + costPriority: float | None = None + """ + How much to prioritize cost when selecting a model. A value of 0 means cost + is not important, while a value of 1 means cost is the most important + factor. + """ + + speedPriority: float | None = None + """ + How much to prioritize sampling speed (latency) when selecting a model. A + value of 0 means speed is not important, while a value of 1 means speed is + the most important factor. + """ + + intelligencePriority: float | None = None + """ + How much to prioritize intelligence and capabilities when selecting a + model. A value of 0 means intelligence is not important, while a value of 1 + means intelligence is the most important factor. + """ + + model_config = ConfigDict(extra="allow") + + class CreateMessageRequestParams(RequestParams): """Parameters for creating a message.""" messages: list[SamplingMessage] + modelPreferences: ModelPreferences | None = None + """The server's preferences for which model to select. The client MAY ignore these preferences.""" systemPrompt: str | None = None """An optional system prompt the server wants to use for sampling.""" includeContext: IncludeContext | None = None @@ -700,7 +785,7 @@ class CreateMessageRequest(Request): params: CreateMessageRequestParams -StopReason = Literal["endTurn", "stopSequence", "maxTokens"] +StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str class CreateMessageResult(Result): @@ -710,8 +795,8 @@ class CreateMessageResult(Result): content: TextContent | ImageContent model: str """The name of the model that generated the message.""" - stopReason: StopReason - """The reason why sampling stopped.""" + stopReason: StopReason | None = None + """The reason why sampling stopped, if known.""" class ResourceReference(BaseModel): @@ -781,6 +866,60 @@ class CompleteResult(Result): completion: Completion +class ListRootsRequest(Request): + """ + Sent from the server to request a list of root URIs from the client. Roots allow + servers to ask for specific directories or files to operate on. A common example + for roots is providing a set of repositories or directories a server should operate + on. + + This request is typically used when the server needs to understand the file system + structure or access specific locations that the client has permission to read from. + """ + + method: Literal["roots/list"] + params: RequestParams | None = None + + +class Root(BaseModel): + """Represents a root directory or file that the server can operate on.""" + + uri: FileUrl + """ + The URI identifying the root. This *must* start with file:// for now. + This restriction may be relaxed in future versions of the protocol to allow + other URI schemes. + """ + name: str | None = None + """ + An optional name for the root. This can be used to provide a human-readable + identifier for the root, which may be useful for display purposes or for + referencing the root in other parts of the application. + """ + model_config = ConfigDict(extra="allow") + + +class ListRootsResult(Result): + """ + The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory + or file that the server can operate on. + """ + + roots: list[Root] + + +class RootsListChangedNotification(Notification): + """ + A notification from the client to the server, informing it that the list of roots has changed. + This notification should be sent whenever the client adds, removes, or modifies any root. + The server should then request an updated list of roots using the ListRootsRequest. + """ + + method: Literal["notifications/roots/list_changed"] + params: NotificationParams | None = None + + class ClientRequest( RootModel[ PingRequest @@ -801,15 +940,15 @@ class ClientRequest( pass -class ClientNotification(RootModel[ProgressNotification | InitializedNotification]): +class ClientNotification(RootModel[ProgressNotification | InitializedNotification | RootsListChangedNotification]): pass -class ClientResult(RootModel[EmptyResult | CreateMessageResult]): +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest]): +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): pass From 4ac03d40f9037abe434fa45095dff874a3c34bb0 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 6 Nov 2024 12:24:53 +0000 Subject: [PATCH 2/5] Update convenience methods on ClientSession and ServerSession --- mcp_python/client/session.py | 85 +++++++++++++++++++++++++++++++++++- mcp_python/server/session.py | 51 +++++++++++++++++++++- 2 files changed, 133 insertions(+), 3 deletions(-) diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 82ec0b2bd..663109faf 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -1,7 +1,7 @@ from datetime import timedelta from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic import AnyUrl, FileUrl from mcp_python.shared.session import BaseSession from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -12,14 +12,21 @@ ClientNotification, ClientRequest, ClientResult, + CompleteResult, EmptyResult, + GetPromptResult, Implementation, InitializedNotification, InitializeResult, JSONRPCMessage, + ListPromptsResult, ListResourcesResult, + ListRootsResult, + ListToolsResult, LoggingLevel, + PromptReference, ReadResourceResult, + ResourceReference, ServerNotification, ServerRequest, ) @@ -61,7 +68,12 @@ async def initialize(self) -> InitializeResult: params=InitializeRequestParams( protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ClientCapabilities( - sampling=None, experimental=None + sampling=None, + experimental=None, + roots={ + # TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported? + "listChanged": True + } ), clientInfo=Implementation(name="mcp_python", version="0.1.0"), ), @@ -220,3 +232,72 @@ async def call_tool( ), CallToolResult, ) + + async def list_prompts(self) -> ListPromptsResult: + """Send a prompts/list request.""" + from mcp_python.types import ListPromptsRequest + + return await self.send_request( + ClientRequest( + ListPromptsRequest( + method="prompts/list", + ) + ), + ListPromptsResult, + ) + + async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult: + """Send a prompts/get request.""" + from mcp_python.types import GetPromptRequest, GetPromptRequestParams + + return await self.send_request( + ClientRequest( + GetPromptRequest( + method="prompts/get", + params=GetPromptRequestParams(name=name, arguments=arguments), + ) + ), + GetPromptResult, + ) + + async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult: + """Send a completion/complete request.""" + from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument + + return await self.send_request( + ClientRequest( + CompleteRequest( + method="completion/complete", + params=CompleteRequestParams( + ref=ref, + argument=CompletionArgument(**argument), + ), + ) + ), + CompleteResult, + ) + + async def list_tools(self) -> ListToolsResult: + """Send a tools/list request.""" + from mcp_python.types import ListToolsRequest + + return await self.send_request( + ClientRequest( + ListToolsRequest( + method="tools/list", + ) + ), + ListToolsResult, + ) + + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + from mcp_python.types import RootsListChangedNotification + + await self.send_notification( + ClientNotification( + RootsListChangedNotification( + method="notifications/roots/list_changed", + ) + ) + ) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index be8a4df81..8e618143b 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -12,7 +12,7 @@ RequestResponder, ) from mcp_python.types import ( - LATEST_PROTOCOL_VERSION, + ListRootsResult, LATEST_PROTOCOL_VERSION, ClientNotification, ClientRequest, CreateMessageResult, @@ -28,6 +28,10 @@ ServerNotification, ServerRequest, ServerResult, + ResourceListChangedNotification, + ToolListChangedNotification, + PromptListChangedNotification, + ModelPreferences, ) @@ -142,6 +146,7 @@ async def request_create_message( temperature: float | None = None, stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, ) -> CreateMessageResult: """Send a sampling/create_message request.""" from mcp_python.types import ( @@ -161,12 +166,26 @@ async def request_create_message( maxTokens=max_tokens, stopSequences=stop_sequences, metadata=metadata, + modelPreferences=model_preferences, ), ) ), CreateMessageResult, ) + async def list_roots(self) -> ListRootsResult: + """Send a roots/list request.""" + from mcp_python.types import ListRootsRequest + + return await self.send_request( + ServerRequest( + ListRootsRequest( + method="roots/list", + ) + ), + ListRootsResult, + ) + async def send_ping(self) -> EmptyResult: """Send a ping request.""" from mcp_python.types import PingRequest @@ -198,3 +217,33 @@ async def send_progress_notification( ) ) ) + + async def send_resource_list_changed(self) -> None: + """Send a resource list changed notification.""" + await self.send_notification( + ServerNotification( + ResourceListChangedNotification( + method="notifications/resources/list_changed", + ) + ) + ) + + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + await self.send_notification( + ServerNotification( + ToolListChangedNotification( + method="notifications/tools/list_changed", + ) + ) + ) + + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + await self.send_notification( + ServerNotification( + PromptListChangedNotification( + method="notifications/prompts/list_changed", + ) + ) + ) From 185a18621526a1521d33a5b1e59fe72efe9447cf Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 6 Nov 2024 12:25:09 +0000 Subject: [PATCH 3/5] Rename request_create_message for consistency --- mcp_python/server/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index 8e618143b..7ecdf1915 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -136,7 +136,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: ) ) - async def request_create_message( + async def create_message( self, messages: list[SamplingMessage], *, From a891ad4689d90e086c27873c39ce2d2e0f3c5607 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 6 Nov 2024 12:33:50 +0000 Subject: [PATCH 4/5] Update tool calls to use structured results --- mcp_python/server/__init__.py | 56 +++++++++++++++++++++++++++++------ mcp_python/server/types.py | 11 +++++-- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index ff8900fc7..5c85eeaa7 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -42,6 +42,9 @@ SubscribeRequest, Tool, UnsubscribeRequest, + TextContent, + EmbeddedResource, + PromptMessage, ) logger = logging.getLogger(__name__) @@ -117,8 +120,6 @@ def get_prompt(self): GetPromptRequest, GetPromptResult, ImageContent, - SamplingMessage, - TextContent, ) from mcp_python.types import ( Role as Role, @@ -133,7 +134,7 @@ def decorator( async def handler(req: GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - messages: list[SamplingMessage] = [] + messages: list[PromptMessage] = [] for message in prompt_get.messages: match message.content: case str() as text_content: @@ -144,15 +145,20 @@ async def handler(req: GetPromptRequest): data=img_content.data, mimeType=img_content.mime_type, ) + case types.EmbeddedResource() as resource: + content = EmbeddedResource( + type="resource", + resource=resource.resource + ) case _: raise ValueError( f"Unexpected content type: {type(message.content)}" ) - sampling_message = SamplingMessage( + prompt_message = PromptMessage( role=message.role, content=content ) - messages.append(sampling_message) + messages.append(prompt_message) return ServerResult( GetPromptResult(description=prompt_get.desc, messages=messages) @@ -276,14 +282,46 @@ async def handler(_: Any): return decorator def call_tool(self): - from mcp_python.types import CallToolResult + from mcp_python.types import CallToolResult, TextContent, ImageContent, EmbeddedResource - def decorator(func: Callable[..., Awaitable[Any]]): + def decorator( + func: Callable[..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]] + ): logger.debug("Registering handler for CallToolRequest") async def handler(req: CallToolRequest): - result = await func(req.params.name, (req.params.arguments or {})) - return ServerResult(CallToolResult(toolResult=result)) + try: + results = await func(req.params.name, (req.params.arguments or {})) + content = [] + for result in results: + match result: + case str() as text: + content.append(TextContent(type="text", text=text)) + case types.ImageContent() as img: + content.append(ImageContent( + type="image", + data=img.data, + mimeType=img.mime_type + )) + case types.EmbeddedResource() as resource: + content.append(EmbeddedResource( + type="resource", + resource=resource.resource + )) + + return ServerResult( + CallToolResult( + content=content, + isError=False + ) + ) + except Exception as e: + return ServerResult( + CallToolResult( + content=[TextContent(type="text", text=str(e))], + isError=True + ) + ) self.request_handlers[CallToolRequest] = handler return func diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py index 76324060e..437bc2948 100644 --- a/mcp_python/server/types.py +++ b/mcp_python/server/types.py @@ -1,5 +1,5 @@ """ -This module provides simpler types to use with the server for managing prompts. +This module provides simpler types to use with the server for managing prompts and tools. """ from dataclasses import dataclass @@ -7,7 +7,7 @@ from pydantic import BaseModel -from mcp_python.types import Role, ServerCapabilities +from mcp_python.types import Role, ServerCapabilities, TextResourceContents, BlobResourceContents @dataclass @@ -17,10 +17,15 @@ class ImageContent: mime_type: str +@dataclass +class EmbeddedResource: + resource: TextResourceContents | BlobResourceContents + + @dataclass class Message: role: Role - content: str | ImageContent + content: str | ImageContent | EmbeddedResource @dataclass From c7d8f11e0c5aa51ed1ed3890ca2bc995090bb731 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 6 Nov 2024 12:35:32 +0000 Subject: [PATCH 5/5] Formatting --- mcp_python/client/session.py | 23 +++++++++----- mcp_python/server/__init__.py | 59 +++++++++++++++++------------------ mcp_python/server/session.py | 9 +++--- mcp_python/server/types.py | 10 ++++-- mcp_python/shared/memory.py | 10 +++--- mcp_python/shared/session.py | 4 +-- mcp_python/types.py | 27 +++++++++++----- tests/conftest.py | 3 +- 8 files changed, 87 insertions(+), 58 deletions(-) diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 663109faf..6c6d01fb3 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -1,7 +1,7 @@ from datetime import timedelta from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, FileUrl +from pydantic import AnyUrl from mcp_python.shared.session import BaseSession from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -21,7 +21,6 @@ JSONRPCMessage, ListPromptsResult, ListResourcesResult, - ListRootsResult, ListToolsResult, LoggingLevel, PromptReference, @@ -71,9 +70,11 @@ async def initialize(self) -> InitializeResult: sampling=None, experimental=None, roots={ - # TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported? + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? "listChanged": True - } + }, ), clientInfo=Implementation(name="mcp_python", version="0.1.0"), ), @@ -246,7 +247,9 @@ async def list_prompts(self) -> ListPromptsResult: ListPromptsResult, ) - async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult: """Send a prompts/get request.""" from mcp_python.types import GetPromptRequest, GetPromptRequestParams @@ -260,9 +263,15 @@ async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) - GetPromptResult, ) - async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult: + async def complete( + self, ref: ResourceReference | PromptReference, argument: dict + ) -> CompleteResult: """Send a completion/complete request.""" - from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument + from mcp_python.types import ( + CompleteRequest, + CompleteRequestParams, + CompletionArgument, + ) return await self.send_request( ClientRequest( diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index 5c85eeaa7..26d690250 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -18,6 +18,7 @@ ClientNotification, ClientRequest, CompleteRequest, + EmbeddedResource, EmptyResult, ErrorData, JSONRPCMessage, @@ -31,6 +32,7 @@ PingRequest, ProgressNotification, Prompt, + PromptMessage, PromptReference, ReadResourceRequest, ReadResourceResult, @@ -40,11 +42,9 @@ ServerResult, SetLevelRequest, SubscribeRequest, + TextContent, Tool, UnsubscribeRequest, - TextContent, - EmbeddedResource, - PromptMessage, ) logger = logging.getLogger(__name__) @@ -147,17 +147,14 @@ async def handler(req: GetPromptRequest): ) case types.EmbeddedResource() as resource: content = EmbeddedResource( - type="resource", - resource=resource.resource + type="resource", resource=resource.resource ) case _: raise ValueError( f"Unexpected content type: {type(message.content)}" ) - prompt_message = PromptMessage( - role=message.role, content=content - ) + prompt_message = PromptMessage(role=message.role, content=content) messages.append(prompt_message) return ServerResult( @@ -175,9 +172,7 @@ def decorator(func: Callable[[], Awaitable[list[Resource]]]): async def handler(_: Any): resources = await func() - return ServerResult( - ListResourcesResult(resources=resources) - ) + return ServerResult(ListResourcesResult(resources=resources)) self.request_handlers[ListResourcesRequest] = handler return func @@ -222,7 +217,6 @@ async def handler(req: ReadResourceRequest): return decorator - def set_logging_level(self): from mcp_python.types import EmptyResult @@ -282,10 +276,17 @@ async def handler(_: Any): return decorator def call_tool(self): - from mcp_python.types import CallToolResult, TextContent, ImageContent, EmbeddedResource + from mcp_python.types import ( + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + ) def decorator( - func: Callable[..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]] + func: Callable[ + ..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]] + ], ): logger.debug("Registering handler for CallToolRequest") @@ -298,28 +299,26 @@ async def handler(req: CallToolRequest): case str() as text: content.append(TextContent(type="text", text=text)) case types.ImageContent() as img: - content.append(ImageContent( - type="image", - data=img.data, - mimeType=img.mime_type - )) + content.append( + ImageContent( + type="image", + data=img.data, + mimeType=img.mime_type, + ) + ) case types.EmbeddedResource() as resource: - content.append(EmbeddedResource( - type="resource", - resource=resource.resource - )) + content.append( + EmbeddedResource( + type="resource", resource=resource.resource + ) + ) - return ServerResult( - CallToolResult( - content=content, - isError=False - ) - ) + return ServerResult(CallToolResult(content=content, isError=False)) except Exception as e: return ServerResult( CallToolResult( content=[TextContent(type="text", text=str(e))], - isError=True + isError=True, ) ) diff --git a/mcp_python/server/session.py b/mcp_python/server/session.py index 7ecdf1915..db7ebc2c2 100644 --- a/mcp_python/server/session.py +++ b/mcp_python/server/session.py @@ -12,7 +12,7 @@ RequestResponder, ) from mcp_python.types import ( - ListRootsResult, LATEST_PROTOCOL_VERSION, + LATEST_PROTOCOL_VERSION, ClientNotification, ClientRequest, CreateMessageResult, @@ -23,15 +23,16 @@ InitializeRequest, InitializeResult, JSONRPCMessage, + ListRootsResult, LoggingLevel, + ModelPreferences, + PromptListChangedNotification, + ResourceListChangedNotification, SamplingMessage, ServerNotification, ServerRequest, ServerResult, - ResourceListChangedNotification, ToolListChangedNotification, - PromptListChangedNotification, - ModelPreferences, ) diff --git a/mcp_python/server/types.py b/mcp_python/server/types.py index 437bc2948..acc5c1ea3 100644 --- a/mcp_python/server/types.py +++ b/mcp_python/server/types.py @@ -1,5 +1,6 @@ """ -This module provides simpler types to use with the server for managing prompts and tools. +This module provides simpler types to use with the server for managing prompts +and tools. """ from dataclasses import dataclass @@ -7,7 +8,12 @@ from pydantic import BaseModel -from mcp_python.types import Role, ServerCapabilities, TextResourceContents, BlobResourceContents +from mcp_python.types import ( + BlobResourceContents, + Role, + ServerCapabilities, + TextResourceContents, +) @dataclass diff --git a/mcp_python/shared/memory.py b/mcp_python/shared/memory.py index a2917499a..6ebfe9f3b 100644 --- a/mcp_python/shared/memory.py +++ b/mcp_python/shared/memory.py @@ -15,14 +15,14 @@ MessageStream = tuple[ MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage] + MemoryObjectSendStream[JSONRPCMessage], ] + @asynccontextmanager -async def create_client_server_memory_streams() -> AsyncGenerator[ - tuple[MessageStream, MessageStream], - None -]: +async def create_client_server_memory_streams() -> ( + AsyncGenerator[tuple[MessageStream, MessageStream], None] +): """ Creates a pair of bidirectional memory streams for client-server communication. diff --git a/mcp_python/shared/session.py b/mcp_python/shared/session.py index f063a33bd..95e354b39 100644 --- a/mcp_python/shared/session.py +++ b/mcp_python/shared/session.py @@ -154,7 +154,8 @@ async def send_request( try: with anyio.fail_after( - None if self._read_timeout_seconds is None + None + if self._read_timeout_seconds is None else self._read_timeout_seconds.total_seconds() ): response_or_error = await response_stream_reader.receive() @@ -168,7 +169,6 @@ async def send_request( f"{self._read_timeout_seconds} seconds." ), ) - ) if isinstance(response_or_error, JSONRPCError): diff --git a/mcp_python/types.py b/mcp_python/types.py index 0e26e7460..4e8071918 100644 --- a/mcp_python/types.py +++ b/mcp_python/types.py @@ -654,7 +654,9 @@ class ToolListChangedNotification(Notification): params: NotificationParams | None = None -LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] +LoggingLevel = Literal[ + "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency" +] class SetLevelRequestParams(RequestParams): @@ -708,7 +710,8 @@ class ModelHint(BaseModel): class ModelPreferences(BaseModel): """ - The server's preferences for model selection, requested of the client during sampling. + The server's preferences for model selection, requested of the client during + sampling. Because LLMs can vary along multiple dimensions, choosing the "best" model is rarely straightforward. Different models excel in different areas—some are @@ -761,7 +764,10 @@ class CreateMessageRequestParams(RequestParams): messages: list[SamplingMessage] modelPreferences: ModelPreferences | None = None - """The server's preferences for which model to select. The client MAY ignore these preferences.""" + """ + The server's preferences for which model to select. The client MAY ignore + these preferences. + """ systemPrompt: str | None = None """An optional system prompt the server wants to use for sampling.""" includeContext: IncludeContext | None = None @@ -911,9 +917,12 @@ class ListRootsResult(Result): class RootsListChangedNotification(Notification): """ - A notification from the client to the server, informing it that the list of roots has changed. - This notification should be sent whenever the client adds, removes, or modifies any root. - The server should then request an updated list of roots using the ListRootsRequest. + A notification from the client to the server, informing it that the list of + roots has changed. + + This notification should be sent whenever the client adds, removes, or + modifies any root. The server should then request an updated list of roots + using the ListRootsRequest. """ method: Literal["notifications/roots/list_changed"] @@ -940,7 +949,11 @@ class ClientRequest( pass -class ClientNotification(RootModel[ProgressNotification | InitializedNotification | RootsListChangedNotification]): +class ClientNotification( + RootModel[ + ProgressNotification | InitializedNotification | RootsListChangedNotification + ] +): pass diff --git a/tests/conftest.py b/tests/conftest.py index 37ff5a4ec..28690b249 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ capabilities=ServerCapabilities(), ) + @pytest.fixture def mcp_server() -> Server: server = Server(name="test_server") @@ -21,7 +22,7 @@ async def handle_list_resources(): Resource( uri=AnyUrl("memory://test"), name="Test Resource", - description="A test resource" + description="A test resource", ) ]