Skip to content

Commit 70014a2

Browse files
authored
Support for http request injection propagation to tools (#816)
1 parent 532b117 commit 70014a2

File tree

12 files changed

+413
-35
lines changed

12 files changed

+413
-35
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from mcp.server.stdio import stdio_server
5050
from mcp.server.streamable_http import EventStore
5151
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
52-
from mcp.shared.context import LifespanContextT, RequestContext
52+
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5353
from mcp.types import (
5454
AnyFunction,
5555
EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
124124
def lifespan_wrapper(
125125
app: FastMCP,
126126
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
127-
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
127+
) -> Callable[
128+
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
129+
]:
128130
@asynccontextmanager
129-
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
131+
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
130132
async with lifespan(app) as context:
131133
yield context
132134

@@ -260,7 +262,7 @@ async def list_tools(self) -> list[MCPTool]:
260262
for info in tools
261263
]
262264

263-
def get_context(self) -> Context[ServerSession, object]:
265+
def get_context(self) -> Context[ServerSession, object, Request]:
264266
"""
265267
Returns a Context object. Note that the context will only be valid
266268
during a request; outside a request, most methods will error.
@@ -893,7 +895,7 @@ def _convert_to_content(
893895
return [TextContent(type="text", text=result)]
894896

895897

896-
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
898+
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
897899
"""Context object providing access to MCP capabilities.
898900
899901
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -927,13 +929,15 @@ def my_tool(x: int, ctx: Context) -> str:
927929
The context is optional - tools that don't need it can omit the parameter.
928930
"""
929931

930-
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
932+
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
931933
_fastmcp: FastMCP | None
932934

933935
def __init__(
934936
self,
935937
*,
936-
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
938+
request_context: (
939+
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
940+
) = None,
937941
fastmcp: FastMCP | None = None,
938942
**kwargs: Any,
939943
):
@@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP:
949953
return self._fastmcp
950954

951955
@property
952-
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
956+
def request_context(
957+
self,
958+
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
953959
"""Access to the underlying request context."""
954960
if self._request_context is None:
955961
raise ValueError("Context is not available outside of a request")

src/mcp/server/fastmcp/tools/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if TYPE_CHECKING:
1515
from mcp.server.fastmcp.server import Context
1616
from mcp.server.session import ServerSessionT
17-
from mcp.shared.context import LifespanContextT
17+
from mcp.shared.context import LifespanContextT, RequestT
1818

1919

2020
class Tool(BaseModel):
@@ -85,7 +85,7 @@ def from_function(
8585
async def run(
8686
self,
8787
arguments: dict[str, Any],
88-
context: Context[ServerSessionT, LifespanContextT] | None = None,
88+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
8989
) -> Any:
9090
"""Run the tool with arguments."""
9191
try:

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mcp.server.fastmcp.exceptions import ToolError
77
from mcp.server.fastmcp.tools.base import Tool
88
from mcp.server.fastmcp.utilities.logging import get_logger
9-
from mcp.shared.context import LifespanContextT
9+
from mcp.shared.context import LifespanContextT, RequestT
1010
from mcp.types import ToolAnnotations
1111

1212
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ async def call_tool(
6565
self,
6666
name: str,
6767
arguments: dict[str, Any],
68-
context: Context[ServerSessionT, LifespanContextT] | None = None,
68+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
6969
) -> Any:
7070
"""Call a tool by name with arguments."""
7171
tool = self.get_tool(name)

src/mcp/server/lowlevel/server.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ async def main():
7272
import warnings
7373
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7474
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
75-
from typing import Any, Generic, TypeVar
75+
from typing import Any, Generic
7676

7777
import anyio
7878
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7979
from pydantic import AnyUrl
80+
from typing_extensions import TypeVar
8081

8182
import mcp.types as types
8283
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ async def main():
8586
from mcp.server.stdio import stdio_server as stdio_server
8687
from mcp.shared.context import RequestContext
8788
from mcp.shared.exceptions import McpError
88-
from mcp.shared.message import SessionMessage
89+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
8990
from mcp.shared.session import RequestResponder
9091

9192
logger = logging.getLogger(__name__)
9293

9394
LifespanResultT = TypeVar("LifespanResultT")
95+
RequestT = TypeVar("RequestT", default=Any)
9496

9597
# This will be properly typed in each Server instance's context
96-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
98+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
9799
contextvars.ContextVar("request_ctx")
98100
)
99101

@@ -111,7 +113,7 @@ def __init__(
111113

112114

113115
@asynccontextmanager
114-
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
116+
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
115117
"""Default lifespan context manager that does nothing.
116118
117119
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
123125
yield {}
124126

125127

126-
class Server(Generic[LifespanResultT]):
128+
class Server(Generic[LifespanResultT, RequestT]):
127129
def __init__(
128130
self,
129131
name: str,
130132
version: str | None = None,
131133
instructions: str | None = None,
132134
lifespan: Callable[
133-
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
135+
[Server[LifespanResultT, RequestT]],
136+
AbstractAsyncContextManager[LifespanResultT],
134137
] = lifespan,
135138
):
136139
self.name = name
@@ -215,7 +218,9 @@ def get_capabilities(
215218
)
216219

217220
@property
218-
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
221+
def request_context(
222+
self,
223+
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
219224
"""If called outside of a request context, this will raise a LookupError."""
220225
return request_ctx.get()
221226

@@ -555,6 +560,13 @@ async def _handle_request(
555560

556561
token = None
557562
try:
563+
# Extract request context from message metadata
564+
request_data = None
565+
if message.message_metadata is not None and isinstance(
566+
message.message_metadata, ServerMessageMetadata
567+
):
568+
request_data = message.message_metadata.request_context
569+
558570
# Set our global state that can be retrieved via
559571
# app.get_request_context()
560572
token = request_ctx.set(
@@ -563,6 +575,7 @@ async def _handle_request(
563575
message.request_meta,
564576
session,
565577
lifespan_context,
578+
request=request_data,
566579
)
567580
)
568581
response = await handler(req)

src/mcp/server/sse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55-
from mcp.shared.message import SessionMessage
55+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5656

5757
logger = logging.getLogger(__name__)
5858

@@ -203,7 +203,9 @@ async def handle_post_message(
203203
await writer.send(err)
204204
return
205205

206-
session_message = SessionMessage(message)
206+
# Pass the ASGI scope for framework-agnostic access to request data
207+
metadata = ServerMessageMetadata(request_context=request)
208+
session_message = SessionMessage(message, metadata=metadata)
207209
logger.debug(f"Sending session message to writer: {session_message}")
208210
response = Response("Accepted", status_code=202)
209211
await response(scope, receive, send)

src/mcp/server/streamable_http_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
5656

5757
def __init__(
5858
self,
59-
app: MCPServer[Any],
59+
app: MCPServer[Any, Any],
6060
event_store: EventStore | None = None,
6161
json_response: bool = False,
6262
stateless: bool = False,

src/mcp/shared/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
11+
RequestT = TypeVar("RequestT", default=Any)
1112

1213

1314
@dataclass
14-
class RequestContext(Generic[SessionT, LifespanContextT]):
15+
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1516
request_id: RequestId
1617
meta: RequestParams.Meta | None
1718
session: SessionT
1819
lifespan_context: LifespanContextT
20+
request: RequestT | None = None

src/mcp/shared/message.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class ServerMessageMetadata:
3030
"""Metadata specific to server messages."""
3131

3232
related_request_id: RequestId | None = None
33+
# Request-specific context (e.g., headers, auth info)
34+
request_context: object | None = None
3335

3436

3537
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

src/mcp/shared/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ def __init__(
8181
ReceiveNotificationT
8282
]""",
8383
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
84+
message_metadata: MessageMetadata = None,
8485
) -> None:
8586
self.request_id = request_id
8687
self.request_meta = request_meta
8788
self.request = request
89+
self.message_metadata = message_metadata
8890
self._session = session
8991
self._completed = False
9092
self._cancel_scope = anyio.CancelScope()
@@ -365,6 +367,7 @@ async def _receive_loop(self) -> None:
365367
request=validated_request,
366368
session=self,
367369
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
370+
message_metadata=message.metadata,
368371
)
369372

370373
self._in_flight[responder.request_id] = responder

0 commit comments

Comments
 (0)