diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 4bcfe0a6..d4d1d17f 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -47,7 +47,7 @@ impl Default for StreamableHttpServerConfig { pub struct StreamableHttpService { pub config: StreamableHttpServerConfig, session_manager: Arc, - service_factory: Arc S + Send + Sync>, + service_factory: Arc Result + Send + Sync>, } impl Clone for StreamableHttpService { @@ -92,7 +92,7 @@ where M: SessionManager, { pub fn new( - service_factory: impl Fn() -> S + Send + Sync + 'static, + service_factory: impl Fn() -> Result + Send + Sync + 'static, session_manager: Arc, config: StreamableHttpServerConfig, ) -> Self { @@ -102,7 +102,7 @@ where service_factory: Arc::new(service_factory), } } - fn get_service(&self) -> S { + fn get_service(&self) -> Result { (self.service_factory)() } pub async fn handle(&self, request: Request) -> Response> @@ -318,7 +318,9 @@ where .create_session() .await .map_err(internal_error_response("create session"))?; - let service = self.get_service(); + let service = self + .get_service() + .map_err(internal_error_response("get service"))?; // spawn a task to serve the session tokio::spawn({ let session_manager = self.session_manager.clone(); @@ -372,7 +374,9 @@ where Ok(response) } } else { - let service = self.get_service(); + let service = self + .get_service() + .map_err(internal_error_response("get service"))?; match message { ClientJsonRpcMessage::Request(request) => { let (transport, receiver) = diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 0272050d..b00752ad 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -96,7 +96,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { let service: StreamableHttpService = StreamableHttpService::new( - Calculator::default, + || Ok(Calculator), Default::default(), StreamableHttpServerConfig { stateful_mode: true, diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs index c9d2a3e6..6312180d 100644 --- a/examples/servers/src/counter_hyper_streamable_http.rs +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -12,7 +12,7 @@ use rmcp::transport::streamable_http_server::{ #[tokio::main] async fn main() -> anyhow::Result<()> { let service = TowerToHyperService::new(StreamableHttpService::new( - Counter::new, + || Ok(Counter::new()), LocalSessionManager::default().into(), Default::default(), )); diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index f4fa1d6c..ff00cec6 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> { .init(); let service = StreamableHttpService::new( - Counter::new, + || Ok(Counter::new()), LocalSessionManager::default().into(), Default::default(), );