Skip to content

Commit a1afc21

Browse files
committed
Allow passing initialization options to a session
We need a way for servers to pass initialization options to the session. This is the beginning of this.
1 parent ac6064b commit a1afc21

File tree

5 files changed

+89
-18
lines changed

5 files changed

+89
-18
lines changed

mcp_python/server/__init__.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import logging
33
import warnings
44
from collections.abc import Awaitable, Callable
5-
from typing import Any
5+
from typing import Any, Self
66

77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
88
from pydantic import AnyUrl
99

1010
from mcp_python.server import types
11-
from mcp_python.server.session import ServerSession
11+
from mcp_python.server.session import ServerSession, SessionInitializationOptions
1212
from mcp_python.server.stdio import stdio_server as stdio_server
1313
from mcp_python.shared.context import RequestContext
1414
from mcp_python.shared.session import RequestResponder
@@ -32,6 +32,7 @@
3232
ReadResourceResult,
3333
Resource,
3434
ResourceReference,
35+
ServerCapabilities,
3536
ServerResult,
3637
SetLevelRequest,
3738
SubscribeRequest,
@@ -40,11 +41,27 @@
4041

4142
logger = logging.getLogger(__name__)
4243

43-
4444
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
4545
"request_ctx"
4646
)
4747

48+
def pkg_version(package: str) -> str:
49+
try:
50+
from importlib.metadata import version
51+
return version(package)
52+
except Exception:
53+
return "unknown"
54+
55+
class InitializationOptions(SessionInitializationOptions):
56+
"""Information about a server provided as initialization options when a new session is started."""
57+
58+
@classmethod
59+
def from_server(cls, server: "Server") -> Self:
60+
return cls(
61+
server_name=server.name,
62+
server_version=pkg_version("mcp_python"),
63+
capabilities=server.get_capabilities()
64+
)
4865

4966
class Server:
5067
def __init__(self, name: str):
@@ -276,13 +293,26 @@ async def handler(req: CompleteRequest):
276293

277294
return decorator
278295

296+
def get_capabilities(self) -> ServerCapabilities:
297+
"""Convert existing handlers to a ServerCapabilities object."""
298+
def get_capability(req_type: type) -> dict[str, Any] | None:
299+
return {} if req_type in self.request_handlers else None
300+
301+
return ServerCapabilities(
302+
prompts=get_capability(ListPromptsRequest),
303+
resources=get_capability(ListResourcesRequest),
304+
tools=get_capability(ListPromptsRequest),
305+
logging=get_capability(SetLevelRequest)
306+
)
307+
279308
async def run(
280309
self,
281310
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
282311
write_stream: MemoryObjectSendStream[JSONRPCMessage],
312+
initialization_options: InitializationOptions
283313
):
284314
with warnings.catch_warnings(record=True) as w:
285-
async with ServerSession(read_stream, write_stream) as session:
315+
async with ServerSession(read_stream, write_stream, initialization_options) as session:
286316
async for message in session.incoming_messages:
287317
logger.debug(f"Received message: {message}")
288318

mcp_python/server/__main__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import logging
22
import sys
3-
3+
import importlib.metadata
44
import anyio
55

6-
from mcp_python.server.session import ServerSession
6+
from mcp_python.server.session import SessionInitializationOptions, ServerSession
77
from mcp_python.server.stdio import stdio_server
8+
from mcp_python.types import ServerCapabilities
89

910
if not sys.warnoptions:
1011
import warnings
@@ -26,8 +27,9 @@ async def receive_loop(session: ServerSession):
2627

2728

2829
async def main():
30+
version = importlib.metadata.version("mcp_python")
2931
async with stdio_server() as (read_stream, write_stream):
30-
async with ServerSession(read_stream, write_stream) as session, write_stream:
32+
async with ServerSession(read_stream, write_stream, SessionInitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
3133
await receive_loop(session)
3234

3335

mcp_python/server/session.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import anyio
55
import anyio.lowlevel
66
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7-
from pydantic import AnyUrl
7+
from pydantic import AnyUrl, BaseModel
88

99
from mcp_python.shared.session import (
1010
BaseSession,
@@ -37,6 +37,13 @@ class InitializationState(Enum):
3737
Initialized = 3
3838

3939

40+
class SessionInitializationOptions(BaseModel):
41+
server_name: str
42+
server_version: str
43+
capabilities: ServerCapabilities
44+
45+
46+
4047
class ServerSession(
4148
BaseSession[
4249
ServerRequest,
@@ -52,9 +59,11 @@ def __init__(
5259
self,
5360
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
5461
write_stream: MemoryObjectSendStream[JSONRPCMessage],
62+
init_options: SessionInitializationOptions
5563
) -> None:
5664
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
5765
self._initialization_state = InitializationState.NotInitialized
66+
self._init_options = init_options
5867

5968
async def _received_request(
6069
self, responder: RequestResponder[ClientRequest, ServerResult]
@@ -66,15 +75,10 @@ async def _received_request(
6675
ServerResult(
6776
InitializeResult(
6877
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
69-
capabilities=ServerCapabilities(
70-
logging=None,
71-
resources=None,
72-
tools=None,
73-
experimental=None,
74-
prompts={},
75-
),
78+
capabilities=self._init_options.capabilities,
7679
serverInfo=Implementation(
77-
name="mcp_python", version="0.1.0"
80+
name=self._init_options.server_name,
81+
version=self._init_options.server_version
7882
),
7983
)
8084
)

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,8 @@ target-version = "py38"
3535

3636
[tool.ruff.lint.per-file-ignores]
3737
"__init__.py" = ["F401"]
38+
39+
[tool.uv]
40+
dev-dependencies = [
41+
"trio>=0.26.2",
42+
]

tests/server/test_session.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import pytest
33

44
from mcp_python.client.session import ClientSession
5-
from mcp_python.server.session import ServerSession
5+
from mcp_python.server import Server
6+
from mcp_python.server.session import ServerSession, SessionInitializationOptions
67
from mcp_python.types import (
78
ClientNotification,
89
InitializedNotification,
910
JSONRPCMessage,
11+
ServerCapabilities,
1012
)
1113

1214

@@ -30,7 +32,7 @@ async def run_server():
3032
nonlocal received_initialized
3133

3234
async with ServerSession(
33-
client_to_server_receive, server_to_client_send
35+
client_to_server_receive, server_to_client_send, SessionInitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
3436
) as server_session:
3537
async for message in server_session.incoming_messages:
3638
if isinstance(message, Exception):
@@ -57,3 +59,31 @@ async def run_server():
5759
pass
5860

5961
assert received_initialized
62+
63+
64+
@pytest.mark.anyio
65+
async def test_server_capabilities():
66+
server = Server("test")
67+
68+
# Initially no capabilities
69+
caps = server.get_capabilities()
70+
assert caps.prompts is None
71+
assert caps.resources is None
72+
73+
# Add a prompts handler
74+
@server.list_prompts()
75+
async def list_prompts():
76+
return []
77+
78+
caps = server.get_capabilities()
79+
assert caps.prompts == {}
80+
assert caps.resources is None
81+
82+
# Add a resources handler
83+
@server.list_resources()
84+
async def list_resources():
85+
return []
86+
87+
caps = server.get_capabilities()
88+
assert caps.prompts == {}
89+
assert caps.resources == {}

0 commit comments

Comments
 (0)