Skip to content

Commit 99c402d

Browse files
authored
Merge pull request #42 from modelcontextprotocol/davidsp/types
Types Rework
2 parents 837309c + ec8c85e commit 99c402d

File tree

14 files changed

+279
-503
lines changed

14 files changed

+279
-503
lines changed

src/mcp/client/session.py

Lines changed: 79 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -3,85 +3,56 @@
33
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
44
from pydantic import AnyUrl
55

6+
import mcp.types as types
67
from mcp.shared.session import BaseSession
78
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
8-
from mcp.types import (
9-
LATEST_PROTOCOL_VERSION,
10-
CallToolResult,
11-
ClientCapabilities,
12-
ClientNotification,
13-
ClientRequest,
14-
ClientResult,
15-
CompleteResult,
16-
EmptyResult,
17-
GetPromptResult,
18-
Implementation,
19-
InitializedNotification,
20-
InitializeResult,
21-
JSONRPCMessage,
22-
ListPromptsResult,
23-
ListResourcesResult,
24-
ListToolsResult,
25-
LoggingLevel,
26-
PromptReference,
27-
ReadResourceResult,
28-
ResourceReference,
29-
RootsCapability,
30-
ServerNotification,
31-
ServerRequest,
32-
)
339

3410

3511
class ClientSession(
3612
BaseSession[
37-
ClientRequest,
38-
ClientNotification,
39-
ClientResult,
40-
ServerRequest,
41-
ServerNotification,
13+
types.ClientRequest,
14+
types.ClientNotification,
15+
types.ClientResult,
16+
types.ServerRequest,
17+
types.ServerNotification,
4218
]
4319
):
4420
def __init__(
4521
self,
46-
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
47-
write_stream: MemoryObjectSendStream[JSONRPCMessage],
22+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
23+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
4824
read_timeout_seconds: timedelta | None = None,
4925
) -> None:
5026
super().__init__(
5127
read_stream,
5228
write_stream,
53-
ServerRequest,
54-
ServerNotification,
29+
types.ServerRequest,
30+
types.ServerNotification,
5531
read_timeout_seconds=read_timeout_seconds,
5632
)
5733

58-
async def initialize(self) -> InitializeResult:
59-
from mcp.types import (
60-
InitializeRequest,
61-
InitializeRequestParams,
62-
)
63-
34+
async def initialize(self) -> types.InitializeResult:
6435
result = await self.send_request(
65-
ClientRequest(
66-
InitializeRequest(
36+
types.ClientRequest(
37+
types.InitializeRequest(
6738
method="initialize",
68-
params=InitializeRequestParams(
69-
protocolVersion=LATEST_PROTOCOL_VERSION,
70-
capabilities=ClientCapabilities(
39+
params=types.InitializeRequestParams(
40+
protocolVersion=types.LATEST_PROTOCOL_VERSION,
41+
capabilities=types.ClientCapabilities(
7142
sampling=None,
7243
experimental=None,
73-
roots=RootsCapability(
44+
roots=types.RootsCapability(
7445
# TODO: Should this be based on whether we
7546
# _will_ send notifications, or only whether
7647
# they're supported?
7748
listChanged=True
7849
),
7950
),
80-
clientInfo=Implementation(name="mcp", version="0.1.0"),
51+
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
8152
),
8253
)
8354
),
84-
InitializeResult,
55+
types.InitializeResult,
8556
)
8657

8758
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
@@ -91,40 +62,33 @@ async def initialize(self) -> InitializeResult:
9162
)
9263

9364
await self.send_notification(
94-
ClientNotification(
95-
InitializedNotification(method="notifications/initialized")
65+
types.ClientNotification(
66+
types.InitializedNotification(method="notifications/initialized")
9667
)
9768
)
9869

9970
return result
10071

101-
async def send_ping(self) -> EmptyResult:
72+
async def send_ping(self) -> types.EmptyResult:
10273
"""Send a ping request."""
103-
from mcp.types import PingRequest
104-
10574
return await self.send_request(
106-
ClientRequest(
107-
PingRequest(
75+
types.ClientRequest(
76+
types.PingRequest(
10877
method="ping",
10978
)
11079
),
111-
EmptyResult,
80+
types.EmptyResult,
11281
)
11382

11483
async def send_progress_notification(
11584
self, progress_token: str | int, progress: float, total: float | None = None
11685
) -> None:
11786
"""Send a progress notification."""
118-
from mcp.types import (
119-
ProgressNotification,
120-
ProgressNotificationParams,
121-
)
122-
12387
await self.send_notification(
124-
ClientNotification(
125-
ProgressNotification(
88+
types.ClientNotification(
89+
types.ProgressNotification(
12690
method="notifications/progress",
127-
params=ProgressNotificationParams(
91+
params=types.ProgressNotificationParams(
12892
progressToken=progress_token,
12993
progress=progress,
13094
total=total,
@@ -133,180 +97,137 @@ async def send_progress_notification(
13397
)
13498
)
13599

136-
async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
100+
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
137101
"""Send a logging/setLevel request."""
138-
from mcp.types import (
139-
SetLevelRequest,
140-
SetLevelRequestParams,
141-
)
142-
143102
return await self.send_request(
144-
ClientRequest(
145-
SetLevelRequest(
103+
types.ClientRequest(
104+
types.SetLevelRequest(
146105
method="logging/setLevel",
147-
params=SetLevelRequestParams(level=level),
106+
params=types.SetLevelRequestParams(level=level),
148107
)
149108
),
150-
EmptyResult,
109+
types.EmptyResult,
151110
)
152111

153-
async def list_resources(self) -> ListResourcesResult:
112+
async def list_resources(self) -> types.ListResourcesResult:
154113
"""Send a resources/list request."""
155-
from mcp.types import (
156-
ListResourcesRequest,
157-
)
158-
159114
return await self.send_request(
160-
ClientRequest(
161-
ListResourcesRequest(
115+
types.ClientRequest(
116+
types.ListResourcesRequest(
162117
method="resources/list",
163118
)
164119
),
165-
ListResourcesResult,
120+
types.ListResourcesResult,
166121
)
167122

168-
async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
123+
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
169124
"""Send a resources/read request."""
170-
from mcp.types import (
171-
ReadResourceRequest,
172-
ReadResourceRequestParams,
173-
)
174-
175125
return await self.send_request(
176-
ClientRequest(
177-
ReadResourceRequest(
126+
types.ClientRequest(
127+
types.ReadResourceRequest(
178128
method="resources/read",
179-
params=ReadResourceRequestParams(uri=uri),
129+
params=types.ReadResourceRequestParams(uri=uri),
180130
)
181131
),
182-
ReadResourceResult,
132+
types.ReadResourceResult,
183133
)
184134

185-
async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
135+
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
186136
"""Send a resources/subscribe request."""
187-
from mcp.types import (
188-
SubscribeRequest,
189-
SubscribeRequestParams,
190-
)
191-
192137
return await self.send_request(
193-
ClientRequest(
194-
SubscribeRequest(
138+
types.ClientRequest(
139+
types.SubscribeRequest(
195140
method="resources/subscribe",
196-
params=SubscribeRequestParams(uri=uri),
141+
params=types.SubscribeRequestParams(uri=uri),
197142
)
198143
),
199-
EmptyResult,
144+
types.EmptyResult,
200145
)
201146

202-
async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
147+
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
203148
"""Send a resources/unsubscribe request."""
204-
from mcp.types import (
205-
UnsubscribeRequest,
206-
UnsubscribeRequestParams,
207-
)
208-
209149
return await self.send_request(
210-
ClientRequest(
211-
UnsubscribeRequest(
150+
types.ClientRequest(
151+
types.UnsubscribeRequest(
212152
method="resources/unsubscribe",
213-
params=UnsubscribeRequestParams(uri=uri),
153+
params=types.UnsubscribeRequestParams(uri=uri),
214154
)
215155
),
216-
EmptyResult,
156+
types.EmptyResult,
217157
)
218158

219159
async def call_tool(
220160
self, name: str, arguments: dict | None = None
221-
) -> CallToolResult:
161+
) -> types.CallToolResult:
222162
"""Send a tools/call request."""
223-
from mcp.types import (
224-
CallToolRequest,
225-
CallToolRequestParams,
226-
)
227-
228163
return await self.send_request(
229-
ClientRequest(
230-
CallToolRequest(
164+
types.ClientRequest(
165+
types.CallToolRequest(
231166
method="tools/call",
232-
params=CallToolRequestParams(name=name, arguments=arguments),
167+
params=types.CallToolRequestParams(name=name, arguments=arguments),
233168
)
234169
),
235-
CallToolResult,
170+
types.CallToolResult,
236171
)
237172

238-
async def list_prompts(self) -> ListPromptsResult:
173+
async def list_prompts(self) -> types.ListPromptsResult:
239174
"""Send a prompts/list request."""
240-
from mcp.types import ListPromptsRequest
241-
242175
return await self.send_request(
243-
ClientRequest(
244-
ListPromptsRequest(
176+
types.ClientRequest(
177+
types.ListPromptsRequest(
245178
method="prompts/list",
246179
)
247180
),
248-
ListPromptsResult,
181+
types.ListPromptsResult,
249182
)
250183

251184
async def get_prompt(
252185
self, name: str, arguments: dict[str, str] | None = None
253-
) -> GetPromptResult:
186+
) -> types.GetPromptResult:
254187
"""Send a prompts/get request."""
255-
from mcp.types import GetPromptRequest, GetPromptRequestParams
256-
257188
return await self.send_request(
258-
ClientRequest(
259-
GetPromptRequest(
189+
types.ClientRequest(
190+
types.GetPromptRequest(
260191
method="prompts/get",
261-
params=GetPromptRequestParams(name=name, arguments=arguments),
192+
params=types.GetPromptRequestParams(name=name, arguments=arguments),
262193
)
263194
),
264-
GetPromptResult,
195+
types.GetPromptResult,
265196
)
266197

267198
async def complete(
268-
self, ref: ResourceReference | PromptReference, argument: dict
269-
) -> CompleteResult:
199+
self, ref: types.ResourceReference | types.PromptReference, argument: dict
200+
) -> types.CompleteResult:
270201
"""Send a completion/complete request."""
271-
from mcp.types import (
272-
CompleteRequest,
273-
CompleteRequestParams,
274-
CompletionArgument,
275-
)
276-
277202
return await self.send_request(
278-
ClientRequest(
279-
CompleteRequest(
203+
types.ClientRequest(
204+
types.CompleteRequest(
280205
method="completion/complete",
281-
params=CompleteRequestParams(
206+
params=types.CompleteRequestParams(
282207
ref=ref,
283-
argument=CompletionArgument(**argument),
208+
argument=types.CompletionArgument(**argument),
284209
),
285210
)
286211
),
287-
CompleteResult,
212+
types.CompleteResult,
288213
)
289214

290-
async def list_tools(self) -> ListToolsResult:
215+
async def list_tools(self) -> types.ListToolsResult:
291216
"""Send a tools/list request."""
292-
from mcp.types import ListToolsRequest
293-
294217
return await self.send_request(
295-
ClientRequest(
296-
ListToolsRequest(
218+
types.ClientRequest(
219+
types.ListToolsRequest(
297220
method="tools/list",
298221
)
299222
),
300-
ListToolsResult,
223+
types.ListToolsResult,
301224
)
302225

303226
async def send_roots_list_changed(self) -> None:
304227
"""Send a roots/list_changed notification."""
305-
from mcp.types import RootsListChangedNotification
306-
307228
await self.send_notification(
308-
ClientNotification(
309-
RootsListChangedNotification(
229+
types.ClientNotification(
230+
types.RootsListChangedNotification(
310231
method="notifications/roots/list_changed",
311232
)
312233
)

0 commit comments

Comments
 (0)