Skip to content

Commit dd72b63

Browse files
author
姚临倾
committed
Remove rust-toolchain and enhance tool macro functionality
1. Remove rust-toolchain.toml file - Project doesn't need to restrict specific Rust version - No dependency on components not included in default Rust installation 2. Enhance tool macro functionality - Automatically recognize #[tool(param)] and #[tool(aggr)] parameters, no manual marking needed - Add tool_impl_item functionality, supporting default implementation of ServerHandler trait for structs - Add default_build parameter (defaults to true), can be disabled with default_build=false 3. Update tests and example code - Update calculator and other examples to demonstrate new features - Modify related test cases These improvements simplify API usage and reduce manual configuration needs. In most cases, users only need to mark #[tool(tool_box)] to get default implementation of the ServerHandler trait.
1 parent afb8a90 commit dd72b63

36 files changed

+293
-287
lines changed

crates/rmcp-macros/.DS_Store

6 KB
Binary file not shown.

crates/rmcp-macros/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
use proc_macro::TokenStream;
33

44
mod tool;
5-
65
#[proc_macro_attribute]
76
pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream {
87
tool::tool(attr.into(), input.into())
98
.unwrap_or_else(|err| err.to_compile_error())
109
.into()
11-
}
10+
}

crates/rmcp-macros/src/tool.rs

Lines changed: 175 additions & 119 deletions
Large diffs are not rendered by default.

crates/rmcp/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ rand = { version = "0.9", optional = true }
5555
tokio-stream = { version = "0.1", optional = true }
5656

5757
# macro
58-
rmcp-macros = { version = "0.1", workspace = true, optional = true }
58+
rmcp-macros = { workspace = true, optional = true }
5959

6060
[features]
6161
default = ["base64", "macros", "server"]

crates/rmcp/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ impl Display for ErrorData {
88
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99
write!(f, "{}: {}", self.code.0, self.message)?;
1010
if let Some(data) = &self.data {
11-
write!(f, "({})", data)?;
11+
write!(f, "({data})")?;
1212
}
1313
Ok(())
1414
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,33 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
4646
})
4747
}
4848

49+
/// Like cached_schema_for_type, but ensures the schema type is "object"
50+
pub fn ensure_object_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
51+
let schema = cached_schema_for_type::<T>();
52+
// Create a mutable clone of the schema
53+
let mut schema_clone = (*schema).clone();
54+
55+
// Ensure the schema has type: "object"
56+
if let Some(schema_type) = schema_clone.get_mut("type") {
57+
*schema_type = serde_json::Value::String("object".to_string());
58+
} else {
59+
schema_clone.insert("type".to_string(), serde_json::Value::String("object".to_string()));
60+
}
61+
62+
Arc::new(schema_clone)
63+
}
64+
65+
#[derive(Deserialize, Serialize, Debug,schemars::JsonSchema)]
66+
struct A{
67+
i:i32,
68+
b:i32,
69+
}
70+
4971
/// Deserialize a JSON object into a type
5072
pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::Error> {
5173
serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
5274
crate::Error::invalid_params(
53-
format!("failed to deserialize parameters: {error}", error = e),
75+
format!("failed to deserialize parameters: {e}"),
5476
None,
5577
)
5678
})
@@ -259,7 +281,7 @@ where
259281
let value: P =
260282
serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
261283
crate::Error::invalid_params(
262-
format!("failed to deserialize parameters: {error}", error = e),
284+
format!("failed to deserialize parameters: {e}"),
263285
None,
264286
)
265287
})?;

crates/rmcp/src/model.rs

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ const_string!(InitializedNotificationMethod = "notifications/initialized");
499499
pub type InitializedNotification = NotificationNoParam<InitializedNotificationMethod>;
500500
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
501501
#[serde(rename_all = "camelCase")]
502+
#[derive(Default)]
502503
pub struct InitializeRequestParam {
503504
pub protocol_version: ProtocolVersion,
504505
pub capabilities: ClientCapabilities,
@@ -507,6 +508,7 @@ pub struct InitializeRequestParam {
507508

508509
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
509510
#[serde(rename_all = "camelCase")]
511+
#[derive(Default)]
510512
pub struct InitializeResult {
511513
pub protocol_version: ProtocolVersion,
512514
pub capabilities: ServerCapabilities,
@@ -518,26 +520,7 @@ pub struct InitializeResult {
518520
pub type ServerInfo = InitializeResult;
519521
pub type ClientInfo = InitializeRequestParam;
520522

521-
impl Default for ServerInfo {
522-
fn default() -> Self {
523-
ServerInfo {
524-
protocol_version: ProtocolVersion::default(),
525-
capabilities: ServerCapabilities::default(),
526-
server_info: Implementation::from_build_env(),
527-
instructions: None,
528-
}
529-
}
530-
}
531523

532-
impl Default for ClientInfo {
533-
fn default() -> Self {
534-
ClientInfo {
535-
protocol_version: ProtocolVersion::default(),
536-
capabilities: ClientCapabilities::default(),
537-
client_info: Implementation::from_build_env(),
538-
}
539-
}
540-
}
541524

542525
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
543526
pub struct Implementation {

crates/rmcp/src/service.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ where
546546
let serve_loop_ct = ct.child_token();
547547
let peer_return: Peer<R> = peer.clone();
548548
let handle = tokio::spawn(async move {
549-
let (mut sink, mut stream) = transport.into_transport();
549+
let (sink, stream) = transport.into_transport();
550550
let mut sink = std::pin::pin!(sink);
551551
let mut stream = std::pin::pin!(stream);
552552
let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();

crates/rmcp/src/service/client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ where
4949
.next()
5050
.await
5151
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
52-
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
52+
.map_err(|e| ClientError::Io(std::io::Error::other(e)))
5353
}
5454

5555
/// Helper function to expect a response from the stream
@@ -132,7 +132,7 @@ where
132132
let handle_client_error = |e: ClientError| -> E {
133133
match e {
134134
ClientError::Io(io_err) => io_err.into(),
135-
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
135+
other => std::io::Error::other(format!("{other}")).into(),
136136
}
137137
};
138138

crates/rmcp/src/service/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ where
141141
let handle_server_error = |e: ServerError| -> E {
142142
match e {
143143
ServerError::Io(io_err) => io_err.into(),
144-
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
144+
other => std::io::Error::other(format!("{other}")).into(),
145145
}
146146
};
147147

crates/rmcp/src/transport/auth.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl AuthorizationManager {
174174
.json::<AuthorizationMetadata>()
175175
.await
176176
.map_err(|e| {
177-
AuthError::MetadataError(format!("Failed to parse metadata: {}", e))
177+
AuthError::MetadataError(format!("Failed to parse metadata: {e}"))
178178
})?;
179179
debug!("metadata: {:?}", metadata);
180180
Ok(metadata)
@@ -185,9 +185,9 @@ impl AuthorizationManager {
185185
auth_base.set_path("");
186186

187187
Ok(AuthorizationMetadata {
188-
authorization_endpoint: format!("{}/authorize", auth_base),
189-
token_endpoint: format!("{}/token", auth_base),
190-
registration_endpoint: format!("{}/register", auth_base),
188+
authorization_endpoint: format!("{auth_base}/authorize"),
189+
token_endpoint: format!("{auth_base}/token"),
190+
registration_endpoint: format!("{auth_base}/register"),
191191
issuer: None,
192192
jwks_uri: None,
193193
scopes_supported: None,
@@ -204,15 +204,15 @@ impl AuthorizationManager {
204204
let metadata = self.metadata.as_ref().unwrap();
205205

206206
let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone())
207-
.map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?;
207+
.map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {e}")))?;
208208

209209
let token_url = TokenUrl::new(metadata.token_endpoint.clone())
210-
.map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?;
210+
.map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {e}")))?;
211211

212212
// debug!("token url: {:?}", token_url);
213213
let client_id = ClientId::new(config.client_id);
214214
let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
215-
.map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?;
215+
.map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {e}")))?;
216216

217217
debug!("client_id: {:?}", client_id);
218218
let mut client_builder = BasicClient::new(client_id.clone())
@@ -268,8 +268,7 @@ impl AuthorizationManager {
268268
Err(e) => {
269269
error!("Registration request failed: {}", e);
270270
return Err(AuthError::RegistrationFailed(format!(
271-
"HTTP request error: {}",
272-
e
271+
"HTTP request error: {e}"
273272
)));
274273
}
275274
};
@@ -283,8 +282,7 @@ impl AuthorizationManager {
283282

284283
error!("Registration failed: HTTP {} - {}", status, error_text);
285284
return Err(AuthError::RegistrationFailed(format!(
286-
"HTTP {}: {}",
287-
status, error_text
285+
"HTTP {status}: {error_text}"
288286
)));
289287
}
290288
debug!("registration response: {:?}", response);
@@ -293,8 +291,7 @@ impl AuthorizationManager {
293291
Err(e) => {
294292
error!("Failed to parse registration response: {}", e);
295293
return Err(AuthError::RegistrationFailed(format!(
296-
"analyze response error: {}",
297-
e
294+
"analyze response error: {e}"
298295
)));
299296
}
300297
};
@@ -451,7 +448,7 @@ impl AuthorizationManager {
451448
request: reqwest::RequestBuilder,
452449
) -> Result<reqwest::RequestBuilder, AuthError> {
453450
let token = self.get_access_token().await?;
454-
Ok(request.header(AUTHORIZATION, format!("Bearer {}", token)))
451+
Ok(request.header(AUTHORIZATION, format!("Bearer {token}")))
455452
}
456453

457454
/// handle response, check if need to re-authorize
@@ -497,7 +494,7 @@ impl AuthorizationSession {
497494
{
498495
Ok(config) => config,
499496
Err(e) => {
500-
eprintln!("Dynamic registration failed: {}", e);
497+
eprintln!("Dynamic registration failed: {e}");
501498
// fallback to default config
502499
config
503500
}

crates/rmcp/src/transport/sse.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ impl<C: SseClient<E>, E: std::error::Error + Send + Sync + 'static> Stream for S
307307
if let Some(max_retry_times) = retry_config.max_times {
308308
if *times >= max_retry_times {
309309
self.as_mut().state = SseTransportState::Fatal {
310-
reason: format!("retrying failed after {} times: {}", times, e),
310+
reason: format!("retrying failed after {times} times: {e}"),
311311
};
312312
return self.poll_next(cx);
313313
}

crates/rmcp/src/transport/sse_auth.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl SseClient<reqwest::Error> for AuthorizedSseClient {
8787
let mut request_builder = client
8888
.get(&sse_url)
8989
.header(ACCEPT, MIME_TYPE)
90-
.header(AUTHORIZATION, format!("Bearer {}", token));
90+
.header(AUTHORIZATION, format!("Bearer {token}"));
9191

9292
if let Some(last_event_id) = last_event_id {
9393
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
@@ -136,7 +136,7 @@ impl SseClient<reqwest::Error> for AuthorizedSseClient {
136136
let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?;
137137
let request_builder = client
138138
.post(uri.as_ref())
139-
.header(AUTHORIZATION, format!("Bearer {}", token))
139+
.header(AUTHORIZATION, format!("Bearer {token}"))
140140
.json(&message);
141141

142142
request_builder
@@ -151,8 +151,7 @@ impl SseClient<reqwest::Error> for AuthorizedSseClient {
151151

152152
impl From<AuthError> for SseTransportError<reqwest::Error> {
153153
fn from(err: AuthError) -> Self {
154-
SseTransportError::Io(std::io::Error::new(
155-
std::io::ErrorKind::Other,
154+
SseTransportError::Io(std::io::Error::other(
156155
err.to_string(),
157156
))
158157
}
@@ -175,8 +174,7 @@ where
175174
transport.retry_config = retry_config.unwrap_or_default();
176175
Ok(transport)
177176
}
178-
_ => Err(SseTransportError::Io(std::io::Error::new(
179-
std::io::ErrorKind::Other,
177+
_ => Err(SseTransportError::Io(std::io::Error::other(
180178
"Not authorized".to_string(),
181179
))),
182180
}
Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use rmcp::{
2-
ServerHandler,
3-
model::{ServerCapabilities, ServerInfo},
42
schemars, tool,
53
};
64
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
@@ -11,34 +9,21 @@ pub struct SumRequest {
119
}
1210
#[derive(Debug, Clone, Default)]
1311
pub struct Calculator;
14-
#[tool(tool_box)]
12+
#[tool(tool_box,description = "A simple calculator")]
1513
impl Calculator {
1614
#[tool(description = "Calculate the sum of two numbers")]
17-
fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String {
15+
fn sum(&self, SumRequest { a, b }: SumRequest) -> String {
1816
(a + b).to_string()
1917
}
2018

2119
#[tool(description = "Calculate the sub of two numbers")]
2220
fn sub(
2321
&self,
24-
#[tool(param)]
2522
#[schemars(description = "the left hand side number")]
2623
a: i32,
27-
#[tool(param)]
2824
#[schemars(description = "the right hand side number")]
2925
b: i32,
3026
) -> String {
3127
(a - b).to_string()
3228
}
3329
}
34-
35-
#[tool(tool_box)]
36-
impl ServerHandler for Calculator {
37-
fn get_info(&self) -> ServerInfo {
38-
ServerInfo {
39-
instructions: Some("A simple calculator".into()),
40-
capabilities: ServerCapabilities::builder().enable_tools().build(),
41-
..Default::default()
42-
}
43-
}
44-
}

crates/rmcp/tests/common/handlers.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl ClientHandler for TestClientHandler {
100100
let received_messages = self.received_messages.clone();
101101

102102
async move {
103-
println!("Client: Received log message: {:?}", params);
103+
println!("Client: Received log message: {params:?}");
104104
let mut messages = received_messages.lock().unwrap();
105105
messages.push(params);
106106
receive_signal.notify_one();
@@ -185,7 +185,7 @@ impl ServerHandler for TestServer {
185185
})
186186
.await
187187
{
188-
panic!("Failed to send notification: {}", e);
188+
panic!("Failed to send notification: {e}");
189189
}
190190
Ok(())
191191
}

crates/rmcp/tests/test_complex_schema.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl Demo {
3333
#[tool(description = "LLM")]
3434
async fn chat(
3535
&self,
36-
#[tool(aggr)] chat_request: ChatRequest,
36+
chat_request: ChatRequest,
3737
) -> Result<CallToolResult, McpError> {
3838
let content = Content::json(chat_request)?;
3939
Ok(CallToolResult::success(vec![content]))

crates/rmcp/tests/test_logging.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,19 +210,17 @@ fn test_logging_level_serialization() {
210210
let serialized = serialized.trim_matches('"');
211211
assert_eq!(
212212
serialized, expected,
213-
"LoggingLevel::{:?} should serialize to \"{}\"",
214-
level, expected
213+
"LoggingLevel::{level:?} should serialize to \"{expected}\""
215214
);
216215
}
217216

218217
// Test deserialization from spec strings
219218
for (level, spec_string) in test_cases {
220219
let deserialized: LoggingLevel =
221-
serde_json::from_str(&format!("\"{}\"", spec_string)).unwrap();
220+
serde_json::from_str(&format!("\"{spec_string}\"")).unwrap();
222221
assert_eq!(
223222
deserialized, level,
224-
"\"{}\" should deserialize to LoggingLevel::{:?}",
225-
spec_string, level
223+
"\"{spec_string}\" should deserialize to LoggingLevel::{level:?}"
226224
);
227225
}
228226
}

crates/rmcp/tests/test_notification.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl ServerHandler for Server {
3939
.notify_resource_updated(ResourceUpdatedNotificationParam { uri: uri.clone() })
4040
.await
4141
{
42-
panic!("Failed to send notification: {}", e);
42+
panic!("Failed to send notification: {e}");
4343
}
4444
});
4545

0 commit comments

Comments
 (0)