Skip to content

Commit 05b7156

Browse files
authored
Support for http request injection propagation in StreamableHttp (#833)
1 parent 7f94bef commit 05b7156

File tree

3 files changed

+259
-172
lines changed

3 files changed

+259
-172
lines changed

src/mcp/server/streamable_http.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ async def _handle_post_request(
397397
await response(scope, receive, send)
398398

399399
# Process the message after sending the response
400-
session_message = SessionMessage(message)
400+
metadata = ServerMessageMetadata(request_context=request)
401+
session_message = SessionMessage(message, metadata=metadata)
401402
await writer.send(session_message)
402403

403404
return
@@ -412,7 +413,8 @@ async def _handle_post_request(
412413

413414
if self.is_json_response_enabled:
414415
# Process the message
415-
session_message = SessionMessage(message)
416+
metadata = ServerMessageMetadata(request_context=request)
417+
session_message = SessionMessage(message, metadata=metadata)
416418
await writer.send(session_message)
417419
try:
418420
# Process messages from the request-specific stream
@@ -511,7 +513,8 @@ async def sse_writer():
511513
async with anyio.create_task_group() as tg:
512514
tg.start_soon(response, scope, receive, send)
513515
# Then send the message to be processed by the server
514-
session_message = SessionMessage(message)
516+
metadata = ServerMessageMetadata(request_context=request)
517+
session_message = SessionMessage(message, metadata=metadata)
515518
await writer.send(session_message)
516519
except Exception:
517520
logger.exception("SSE response error")

tests/server/fastmcp/test_integration.py

Lines changed: 51 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from mcp.client.streamable_http import streamablehttp_client
2525
from mcp.server.fastmcp import FastMCP
2626
from mcp.server.fastmcp.resources import FunctionResource
27-
from mcp.server.fastmcp.server import Context
2827
from mcp.shared.context import RequestContext
2928
from mcp.types import (
3029
CreateMessageRequestParams,
@@ -196,6 +195,33 @@ def complex_prompt(user_query: str, context: str = "general") -> str:
196195
# Since FastMCP doesn't support system messages in the same way
197196
return f"Context: {context}. Query: {user_query}"
198197

198+
# Tool that echoes request headers from context
199+
@mcp.tool(description="Echo request headers from context")
200+
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
201+
"""Returns the request headers as JSON."""
202+
headers_info = {}
203+
if ctx.request_context.request:
204+
# Now the type system knows request is a Starlette Request object
205+
headers_info = dict(ctx.request_context.request.headers)
206+
return json.dumps(headers_info)
207+
208+
# Tool that returns full request context
209+
@mcp.tool(description="Echo request context with custom data")
210+
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
211+
"""Returns request context including headers and custom data."""
212+
context_data = {
213+
"custom_request_id": custom_request_id,
214+
"headers": {},
215+
"method": None,
216+
"path": None,
217+
}
218+
if ctx.request_context.request:
219+
request = ctx.request_context.request
220+
context_data["headers"] = dict(request.headers)
221+
context_data["method"] = request.method
222+
context_data["path"] = request.url.path
223+
return json.dumps(context_data)
224+
199225
return mcp
200226

201227

@@ -432,174 +458,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
432458
assert tool_result.content[0].text == "Echo: hello"
433459

434460

435-
def make_fastmcp_with_context_app():
436-
"""Create a FastMCP server that can access request context."""
437-
438-
mcp = FastMCP(name="ContextServer")
439-
440-
# Tool that echoes request headers
441-
@mcp.tool(description="Echo request headers from context")
442-
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
443-
"""Returns the request headers as JSON."""
444-
headers_info = {}
445-
if ctx.request_context.request:
446-
# Now the type system knows request is a Starlette Request object
447-
headers_info = dict(ctx.request_context.request.headers)
448-
return json.dumps(headers_info)
449-
450-
# Tool that returns full request context
451-
@mcp.tool(description="Echo request context with custom data")
452-
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
453-
"""Returns request context including headers and custom data."""
454-
context_data = {
455-
"custom_request_id": custom_request_id,
456-
"headers": {},
457-
"method": None,
458-
"path": None,
459-
}
460-
if ctx.request_context.request:
461-
request = ctx.request_context.request
462-
context_data["headers"] = dict(request.headers)
463-
context_data["method"] = request.method
464-
context_data["path"] = request.url.path
465-
return json.dumps(context_data)
466-
467-
# Create the SSE app
468-
app = mcp.sse_app()
469-
return mcp, app
470-
471-
472-
def run_context_server(server_port: int) -> None:
473-
"""Run the context-aware FastMCP server."""
474-
_, app = make_fastmcp_with_context_app()
475-
server = uvicorn.Server(
476-
config=uvicorn.Config(
477-
app=app, host="127.0.0.1", port=server_port, log_level="error"
478-
)
479-
)
480-
print(f"Starting context server on port {server_port}")
481-
server.run()
482-
483-
484-
@pytest.fixture()
485-
def context_aware_server(server_port: int) -> Generator[None, None, None]:
486-
"""Start the context-aware server in a separate process."""
487-
proc = multiprocessing.Process(
488-
target=run_context_server, args=(server_port,), daemon=True
489-
)
490-
print("Starting context-aware server process")
491-
proc.start()
492-
493-
# Wait for server to be running
494-
max_attempts = 20
495-
attempt = 0
496-
print("Waiting for context-aware server to start")
497-
while attempt < max_attempts:
498-
try:
499-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
500-
s.connect(("127.0.0.1", server_port))
501-
break
502-
except ConnectionRefusedError:
503-
time.sleep(0.1)
504-
attempt += 1
505-
else:
506-
raise RuntimeError(
507-
f"Context server failed to start after {max_attempts} attempts"
508-
)
509-
510-
yield
511-
512-
print("Killing context-aware server")
513-
proc.kill()
514-
proc.join(timeout=2)
515-
if proc.is_alive():
516-
print("Context server process failed to terminate")
517-
518-
519-
@pytest.mark.anyio
520-
async def test_fast_mcp_with_request_context(
521-
context_aware_server: None, server_url: str
522-
) -> None:
523-
"""Test that FastMCP properly propagates request context to tools."""
524-
# Test with custom headers
525-
custom_headers = {
526-
"Authorization": "Bearer fastmcp-test-token",
527-
"X-Custom-Header": "fastmcp-value",
528-
"X-Request-Id": "req-123",
529-
}
530-
531-
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
532-
async with ClientSession(*streams) as session:
533-
# Initialize the session
534-
result = await session.initialize()
535-
assert isinstance(result, InitializeResult)
536-
assert result.serverInfo.name == "ContextServer"
537-
538-
# Test 1: Call tool that echoes headers
539-
headers_result = await session.call_tool("echo_headers", {})
540-
assert len(headers_result.content) == 1
541-
assert isinstance(headers_result.content[0], TextContent)
542-
543-
headers_data = json.loads(headers_result.content[0].text)
544-
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
545-
assert headers_data.get("x-custom-header") == "fastmcp-value"
546-
assert headers_data.get("x-request-id") == "req-123"
547-
548-
# Test 2: Call tool that returns full context
549-
context_result = await session.call_tool(
550-
"echo_context", {"custom_request_id": "test-123"}
551-
)
552-
assert len(context_result.content) == 1
553-
assert isinstance(context_result.content[0], TextContent)
554-
555-
context_data = json.loads(context_result.content[0].text)
556-
assert context_data["custom_request_id"] == "test-123"
557-
assert (
558-
context_data["headers"].get("authorization")
559-
== "Bearer fastmcp-test-token"
560-
)
561-
assert context_data["method"] == "POST" #
562-
563-
564-
@pytest.mark.anyio
565-
async def test_fast_mcp_request_context_isolation(
566-
context_aware_server: None, server_url: str
567-
) -> None:
568-
"""Test that request contexts are isolated between different FastMCP clients."""
569-
contexts = []
570-
571-
# Create multiple clients with different headers
572-
for i in range(3):
573-
headers = {
574-
"Authorization": f"Bearer token-{i}",
575-
"X-Request-Id": f"fastmcp-req-{i}",
576-
"X-Custom-Value": f"value-{i}",
577-
}
578-
579-
async with sse_client(server_url + "/sse", headers=headers) as streams:
580-
async with ClientSession(*streams) as session:
581-
await session.initialize()
582-
583-
# Call the tool that returns context
584-
tool_result = await session.call_tool(
585-
"echo_context", {"custom_request_id": f"test-req-{i}"}
586-
)
587-
588-
# Parse and store the result
589-
assert len(tool_result.content) == 1
590-
assert isinstance(tool_result.content[0], TextContent)
591-
context_data = json.loads(tool_result.content[0].text)
592-
contexts.append(context_data)
593-
594-
# Verify each request had its own isolated context
595-
assert len(contexts) == 3
596-
for i, ctx in enumerate(contexts):
597-
assert ctx["custom_request_id"] == f"test-req-{i}"
598-
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
599-
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
600-
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
601-
602-
603461
@pytest.mark.anyio
604462
async def test_fastmcp_streamable_http(
605463
streamable_http_server: None, http_server_url: str
@@ -967,6 +825,30 @@ async def progress_callback(
967825
assert isinstance(complex_result, GetPromptResult)
968826
assert len(complex_result.messages) >= 1
969827

828+
# Test request context propagation (only works when headers are available)
829+
830+
headers_result = await session.call_tool("echo_headers", {})
831+
assert len(headers_result.content) == 1
832+
assert isinstance(headers_result.content[0], TextContent)
833+
834+
# If we got headers, verify they exist
835+
headers_data = json.loads(headers_result.content[0].text)
836+
# The headers depend on the transport and test setup
837+
print(f"Received headers: {headers_data}")
838+
839+
# Test 6: Call tool that returns full context
840+
context_result = await session.call_tool(
841+
"echo_context", {"custom_request_id": "test-123"}
842+
)
843+
assert len(context_result.content) == 1
844+
assert isinstance(context_result.content[0], TextContent)
845+
846+
context_data = json.loads(context_result.content[0].text)
847+
assert context_data["custom_request_id"] == "test-123"
848+
# The method should be POST for most transports
849+
if context_data["method"]:
850+
assert context_data["method"] == "POST"
851+
970852

971853
async def sampling_callback(
972854
context: RequestContext[ClientSession, None],

0 commit comments

Comments
 (0)