Skip to content

Update to spec version 2024-11-05 #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion mcp_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
ServerNotification,
ServerRequest,
)
Expand Down Expand Up @@ -61,7 +67,14 @@ async def initialize(self) -> InitializeResult:
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
sampling=None, experimental=None
sampling=None,
experimental=None,
roots={
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
"listChanged": True
},
),
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
),
Expand Down Expand Up @@ -220,3 +233,80 @@ async def call_tool(
),
CallToolResult,
)

async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp_python.types import ListPromptsRequest

return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams

return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)

async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)

return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)

async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp_python.types import ListToolsRequest

return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)

async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp_python.types import RootsListChangedNotification

await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)
67 changes: 52 additions & 15 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ClientNotification,
ClientRequest,
CompleteRequest,
EmbeddedResource,
EmptyResult,
ErrorData,
JSONRPCMessage,
Expand All @@ -31,6 +32,7 @@
PingRequest,
ProgressNotification,
Prompt,
PromptMessage,
PromptReference,
ReadResourceRequest,
ReadResourceResult,
Expand All @@ -40,6 +42,7 @@
ServerResult,
SetLevelRequest,
SubscribeRequest,
TextContent,
Tool,
UnsubscribeRequest,
)
Expand Down Expand Up @@ -117,8 +120,6 @@ def get_prompt(self):
GetPromptRequest,
GetPromptResult,
ImageContent,
SamplingMessage,
TextContent,
)
from mcp_python.types import (
Role as Role,
Expand All @@ -133,7 +134,7 @@ def decorator(

async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
messages: list[SamplingMessage] = []
messages: list[PromptMessage] = []
for message in prompt_get.messages:
match message.content:
case str() as text_content:
Expand All @@ -144,15 +145,17 @@ async def handler(req: GetPromptRequest):
data=img_content.data,
mimeType=img_content.mime_type,
)
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource", resource=resource.resource
)
case _:
raise ValueError(
f"Unexpected content type: {type(message.content)}"
)

sampling_message = SamplingMessage(
role=message.role, content=content
)
messages.append(sampling_message)
prompt_message = PromptMessage(role=message.role, content=content)
messages.append(prompt_message)

return ServerResult(
GetPromptResult(description=prompt_get.desc, messages=messages)
Expand All @@ -169,9 +172,7 @@ def decorator(func: Callable[[], Awaitable[list[Resource]]]):

async def handler(_: Any):
resources = await func()
return ServerResult(
ListResourcesResult(resources=resources)
)
return ServerResult(ListResourcesResult(resources=resources))

self.request_handlers[ListResourcesRequest] = handler
return func
Expand Down Expand Up @@ -216,7 +217,6 @@ async def handler(req: ReadResourceRequest):

return decorator


def set_logging_level(self):
from mcp_python.types import EmptyResult

Expand Down Expand Up @@ -276,14 +276,51 @@ async def handler(_: Any):
return decorator

def call_tool(self):
from mcp_python.types import CallToolResult
from mcp_python.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
)

def decorator(func: Callable[..., Awaitable[Any]]):
def decorator(
func: Callable[
..., Awaitable[list[str | types.ImageContent | types.EmbeddedResource]]
],
):
logger.debug("Registering handler for CallToolRequest")

async def handler(req: CallToolRequest):
result = await func(req.params.name, (req.params.arguments or {}))
return ServerResult(CallToolResult(toolResult=result))
try:
results = await func(req.params.name, (req.params.arguments or {}))
content = []
for result in results:
match result:
case str() as text:
content.append(TextContent(type="text", text=text))
case types.ImageContent() as img:
content.append(
ImageContent(
type="image",
data=img.data,
mimeType=img.mime_type,
)
)
case types.EmbeddedResource() as resource:
content.append(
EmbeddedResource(
type="resource", resource=resource.resource
)
)

return ServerResult(CallToolResult(content=content, isError=False))
except Exception as e:
return ServerResult(
CallToolResult(
content=[TextContent(type="text", text=str(e))],
isError=True,
)
)

self.request_handlers[CallToolRequest] = handler
return func
Expand Down
52 changes: 51 additions & 1 deletion mcp_python/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@
InitializeRequest,
InitializeResult,
JSONRPCMessage,
ListRootsResult,
LoggingLevel,
ModelPreferences,
PromptListChangedNotification,
ResourceListChangedNotification,
SamplingMessage,
ServerNotification,
ServerRequest,
ServerResult,
ToolListChangedNotification,
)


Expand Down Expand Up @@ -132,7 +137,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None:
)
)

async def request_create_message(
async def create_message(
self,
messages: list[SamplingMessage],
*,
Expand All @@ -142,6 +147,7 @@ async def request_create_message(
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp_python.types import (
Expand All @@ -161,12 +167,26 @@ async def request_create_message(
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
),
)
),
CreateMessageResult,
)

async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp_python.types import ListRootsRequest

return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)

async def send_ping(self) -> EmptyResult:
"""Send a ping request."""
from mcp_python.types import PingRequest
Expand Down Expand Up @@ -198,3 +218,33 @@ async def send_progress_notification(
)
)
)

async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)

async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)

async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)
17 changes: 14 additions & 3 deletions mcp_python/server/types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""
This module provides simpler types to use with the server for managing prompts.
This module provides simpler types to use with the server for managing prompts
and tools.
"""

from dataclasses import dataclass
from typing import Literal

from pydantic import BaseModel

from mcp_python.types import Role, ServerCapabilities
from mcp_python.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)


@dataclass
Expand All @@ -17,10 +23,15 @@ class ImageContent:
mime_type: str


@dataclass
class EmbeddedResource:
resource: TextResourceContents | BlobResourceContents


@dataclass
class Message:
role: Role
content: str | ImageContent
content: str | ImageContent | EmbeddedResource


@dataclass
Expand Down
Loading