From 05bcf7eadb38e190a0109dcf532a299b9231eebe Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 30 May 2025 11:28:46 +0200 Subject: [PATCH 1/2] Use 120 characters instead of 88 --- examples/fastmcp/memory.py | 36 +--- examples/fastmcp/text_me.py | 8 +- examples/fastmcp/unicode_example.py | 5 +- .../simple-auth/mcp_simple_auth/server.py | 31 +-- pyproject.toml | 4 +- src/mcp/cli/claude.py | 12 +- src/mcp/cli/cli.py | 36 +--- src/mcp/client/__main__.py | 8 +- src/mcp/client/auth.py | 65 ++----- src/mcp/client/session.py | 71 ++----- src/mcp/client/session_group.py | 12 +- src/mcp/client/sse.py | 35 +--- src/mcp/client/stdio/__init__.py | 14 +- src/mcp/client/stdio/win32.py | 4 +- src/mcp/client/streamable_http.py | 52 ++--- src/mcp/client/websocket.py | 9 +- src/mcp/server/auth/errors.py | 5 +- src/mcp/server/auth/handlers/authorize.py | 49 ++--- src/mcp/server/auth/handlers/register.py | 17 +- src/mcp/server/auth/handlers/revoke.py | 11 +- src/mcp/server/auth/handlers/token.py | 63 ++---- .../server/auth/middleware/auth_context.py | 4 +- src/mcp/server/auth/middleware/bearer_auth.py | 17 +- src/mcp/server/auth/middleware/client_auth.py | 9 +- src/mcp/server/auth/provider.py | 17 +- src/mcp/server/auth/routes.py | 18 +- src/mcp/server/auth/settings.py | 3 +- src/mcp/server/fastmcp/prompts/base.py | 32 +--- src/mcp/server/fastmcp/prompts/manager.py | 4 +- src/mcp/server/fastmcp/resources/base.py | 8 +- src/mcp/server/fastmcp/resources/templates.py | 12 +- src/mcp/server/fastmcp/resources/types.py | 32 +--- src/mcp/server/fastmcp/server.py | 98 +++------- src/mcp/server/fastmcp/tools/base.py | 15 +- src/mcp/server/fastmcp/tools/tool_manager.py | 4 +- .../server/fastmcp/utilities/func_metadata.py | 22 +-- src/mcp/server/lowlevel/server.py | 69 ++----- src/mcp/server/session.py | 43 ++--- src/mcp/server/sse.py | 26 +-- src/mcp/server/stdio.py | 4 +- src/mcp/server/streamable_http.py | 141 ++++---------- src/mcp/server/streamable_http_manager.py | 13 +- src/mcp/server/streaming_asgi_transport.py | 20 +- src/mcp/server/websocket.py | 4 +- src/mcp/shared/auth.py | 29 +-- src/mcp/shared/memory.py | 25 +-- src/mcp/shared/message.py | 4 +- src/mcp/shared/progress.py | 34 +--- src/mcp/shared/session.py | 62 ++---- src/mcp/types.py | 82 ++------ tests/client/conftest.py | 18 +- tests/client/test_auth.py | 180 +++++------------- tests/client/test_config.py | 4 +- tests/client/test_list_methods_cursor.py | 16 +- tests/client/test_list_roots_callback.py | 17 +- tests/client/test_logging_callback.py | 4 +- tests/client/test_sampling_callback.py | 27 +-- tests/client/test_session.py | 104 +++------- tests/client/test_session_group.py | 60 ++---- tests/client/test_stdio.py | 12 +- tests/issues/test_100_tool_listing.py | 8 +- tests/issues/test_129_resource_templates.py | 4 +- tests/issues/test_141_resource_templates.py | 12 +- tests/issues/test_152_resource_mime_type.py | 56 ++---- tests/issues/test_176_progress_token.py | 16 +- tests/issues/test_192_request_id.py | 16 +- tests/issues/test_342_base64_encoding.py | 6 +- tests/issues/test_88_random_error.py | 8 +- .../auth/middleware/test_bearer_auth.py | 20 +- tests/server/auth/test_error_handling.py | 16 +- .../fastmcp/auth/test_auth_integration.py | 124 +++--------- tests/server/fastmcp/prompts/test_base.py | 34 +--- tests/server/fastmcp/prompts/test_manager.py | 8 +- .../fastmcp/resources/test_file_resources.py | 4 +- tests/server/fastmcp/test_func_metadata.py | 12 +- tests/server/fastmcp/test_integration.py | 124 +++--------- tests/server/fastmcp/test_server.py | 47 ++--- tests/server/fastmcp/test_tool_manager.py | 8 +- .../server/test_lowlevel_tool_annotations.py | 20 +- tests/server/test_read_resource.py | 6 +- tests/server/test_session.py | 32 +--- tests/server/test_stdio.py | 27 +-- tests/server/test_streamable_http_manager.py | 14 +- tests/shared/test_progress_notifications.py | 28 +-- tests/shared/test_session.py | 8 +- tests/shared/test_sse.py | 91 ++------- tests/shared/test_streamable_http.py | 141 ++++---------- tests/shared/test_ws.py | 36 +--- tests/test_examples.py | 8 +- 89 files changed, 666 insertions(+), 2108 deletions(-) diff --git a/examples/fastmcp/memory.py b/examples/fastmcp/memory.py index dbc890815..0f97babf1 100644 --- a/examples/fastmcp/memory.py +++ b/examples/fastmcp/memory.py @@ -47,18 +47,14 @@ DB_DSN = "postgresql://postgres:postgres@localhost:54320/memory_db" # reset memory with rm ~/.fastmcp/{USER}/memory/* -PROFILE_DIR = ( - Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory" -).resolve() +PROFILE_DIR = (Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory").resolve() PROFILE_DIR.mkdir(parents=True, exist_ok=True) def cosine_similarity(a: list[float], b: list[float]) -> float: a_array = np.array(a, dtype=np.float64) b_array = np.array(b, dtype=np.float64) - return np.dot(a_array, b_array) / ( - np.linalg.norm(a_array) * np.linalg.norm(b_array) - ) + return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array)) async def do_ai[T]( @@ -97,9 +93,7 @@ class MemoryNode(BaseModel): summary: str = "" importance: float = 1.0 access_count: int = 0 - timestamp: float = Field( - default_factory=lambda: datetime.now(timezone.utc).timestamp() - ) + timestamp: float = Field(default_factory=lambda: datetime.now(timezone.utc).timestamp()) embedding: list[float] @classmethod @@ -152,9 +146,7 @@ async def merge_with(self, other: Self, deps: Deps): self.importance += other.importance self.access_count += other.access_count self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)] - self.summary = await do_ai( - self.content, "Summarize the following text concisely.", str, deps - ) + self.summary = await do_ai(self.content, "Summarize the following text concisely.", str, deps) await self.save(deps) # Delete the merged node from the database if other.id is not None: @@ -221,9 +213,7 @@ async def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo async def update_importance(user_embedding: list[float], deps: Deps): async with deps.pool.acquire() as conn: - rows = await conn.fetch( - "SELECT id, importance, access_count, embedding FROM memories" - ) + rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories") for row in rows: memory_embedding = row["embedding"] similarity = cosine_similarity(user_embedding, memory_embedding) @@ -273,9 +263,7 @@ async def display_memory_tree(deps: Deps) -> str: ) result = "" for row in rows: - effective_importance = row["importance"] * ( - 1 + math.log(row["access_count"] + 1) - ) + effective_importance = row["importance"] * (1 + math.log(row["access_count"] + 1)) summary = row["summary"] or row["content"] result += f"- {summary} (Importance: {effective_importance:.2f})\n" return result @@ -283,15 +271,11 @@ async def display_memory_tree(deps: Deps) -> str: @mcp.tool() async def remember( - contents: list[str] = Field( - description="List of observations or memories to store" - ), + contents: list[str] = Field(description="List of observations or memories to store"), ): deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool()) try: - return "\n".join( - await asyncio.gather(*[add_memory(content, deps) for content in contents]) - ) + return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents])) finally: await deps.pool.close() @@ -305,9 +289,7 @@ async def read_profile() -> str: async def initialize_database(): - pool = await asyncpg.create_pool( - "postgresql://postgres:postgres@localhost:54320/postgres" - ) + pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres") try: async with pool.acquire() as conn: await conn.execute(""" diff --git a/examples/fastmcp/text_me.py b/examples/fastmcp/text_me.py index 8053c6cc5..2434dcddd 100644 --- a/examples/fastmcp/text_me.py +++ b/examples/fastmcp/text_me.py @@ -28,15 +28,11 @@ class SurgeSettings(BaseSettings): - model_config: SettingsConfigDict = SettingsConfigDict( - env_prefix="SURGE_", env_file=".env" - ) + model_config: SettingsConfigDict = SettingsConfigDict(env_prefix="SURGE_", env_file=".env") api_key: str account_id: str - my_phone_number: Annotated[ - str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v) - ] + my_phone_number: Annotated[str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)] my_first_name: str my_last_name: str diff --git a/examples/fastmcp/unicode_example.py b/examples/fastmcp/unicode_example.py index a69f586a5..94ef628bb 100644 --- a/examples/fastmcp/unicode_example.py +++ b/examples/fastmcp/unicode_example.py @@ -8,10 +8,7 @@ mcp = FastMCP() -@mcp.tool( - description="🌟 A tool that uses various Unicode characters in its description: " - "á é í ó ú ñ 漢字 🎉" -) +@mcp.tool(description="🌟 A tool that uses various Unicode characters in its description: " "á é í ó ú ñ 漢字 🎉") def hello_unicode(name: str = "世界", greeting: str = "¡Hola") -> str: """ A simple tool that demonstrates Unicode handling in: diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 51f449113..8b25dfd44 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -82,9 +82,7 @@ async def register_client(self, client_info: OAuthClientInformationFull): """Register a new OAuth client.""" self.clients[client_info.client_id] = client_info - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: """Generate an authorization URL for GitHub OAuth flow.""" state = params.state or secrets.token_hex(16) @@ -92,9 +90,7 @@ async def authorize( self.state_mapping[state] = { "redirect_uri": str(params.redirect_uri), "code_challenge": params.code_challenge, - "redirect_uri_provided_explicitly": str( - params.redirect_uri_provided_explicitly - ), + "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), "client_id": client.client_id, } @@ -117,9 +113,7 @@ async def handle_github_callback(self, code: str, state: str) -> str: redirect_uri = state_data["redirect_uri"] code_challenge = state_data["code_challenge"] - redirect_uri_provided_explicitly = ( - state_data["redirect_uri_provided_explicitly"] == "True" - ) + redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" client_id = state_data["client_id"] # Exchange code for token with GitHub @@ -200,8 +194,7 @@ async def exchange_authorization_code( for token, data in self.tokens.items() # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ # which you get depends on your GH app setup. - if (token.startswith("ghu_") or token.startswith("gho_")) - and data.client_id == client.client_id + if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id ), None, ) @@ -232,9 +225,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: return access_token - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshToken | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: """Load a refresh token - not supported.""" return None @@ -247,9 +238,7 @@ async def exchange_refresh_token( """Exchange refresh token""" raise NotImplementedError("Not supported") - async def revoke_token( - self, token: str, token_type_hint: str | None = None - ) -> None: + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: """Revoke a token.""" if token in self.tokens: del self.tokens[token] @@ -335,9 +324,7 @@ async def get_user_profile() -> dict[str, Any]: ) if response.status_code != 200: - raise ValueError( - f"GitHub API error: {response.status_code} - {response.text}" - ) + raise ValueError(f"GitHub API error: {response.status_code} - {response.text}") return response.json() @@ -361,9 +348,7 @@ def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) -> # No hardcoded credentials - all from environment variables settings = ServerSettings(host=host, port=port) except ValueError as e: - logger.error( - "Failed to load settings. Make sure environment variables are set:" - ) + logger.error("Failed to load settings. Make sure environment variables are set:") logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") logger.error(f"Error: {e}") diff --git a/pyproject.toml b/pyproject.toml index 0a11a3b15..56a7d385a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,17 +86,17 @@ Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" packages = ["src/mcp"] [tool.pyright] +typeCheckingMode = "strict" include = ["src/mcp", "tests", "examples/servers"] venvPath = "." venv = ".venv" -strict = ["src/mcp/**/*.py"] [tool.ruff.lint] select = ["C4", "E", "F", "I", "PERF", "UP"] ignore = ["PERF203"] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.ruff.lint.per-file-ignores] diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 1629f9287..e6eab2851 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -21,9 +21,7 @@ def get_claude_config_path() -> Path | None: elif sys.platform == "darwin": path = Path(Path.home(), "Library", "Application Support", "Claude") elif sys.platform.startswith("linux"): - path = Path( - os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude" - ) + path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude") else: return None @@ -37,8 +35,7 @@ def get_uv_path() -> str: uv_path = shutil.which("uv") if not uv_path: logger.error( - "uv executable not found in PATH, falling back to 'uv'. " - "Please ensure uv is installed and in your PATH" + "uv executable not found in PATH, falling back to 'uv'. " "Please ensure uv is installed and in your PATH" ) return "uv" # Fall back to just "uv" if not found return uv_path @@ -94,10 +91,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if ( - server_name in config["mcpServers"] - and "env" in config["mcpServers"][server_name] - ): + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index b2632f1d9..69e2921f1 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -45,9 +45,7 @@ def _get_npx_command(): # Try both npx.cmd and npx.exe on Windows for cmd in ["npx.cmd", "npx.exe", "npx"]: try: - subprocess.run( - [cmd, "--version"], check=True, capture_output=True, shell=True - ) + subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True) return cmd except subprocess.CalledProcessError: continue @@ -58,9 +56,7 @@ def _get_npx_command(): def _parse_env_var(env_var: str) -> tuple[str, str]: """Parse environment variable string in format KEY=VALUE.""" if "=" not in env_var: - logger.error( - f"Invalid environment variable format: {env_var}. Must be KEY=VALUE" - ) + logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE") sys.exit(1) key, value = env_var.split("=", 1) return key.strip(), value.strip() @@ -154,14 +150,10 @@ def _check_server_object(server_object: Any, object_name: str): True if it's supported. """ if not isinstance(server_object, FastMCP): - logger.error( - f"The server object {object_name} is of type " - f"{type(server_object)} (expecting {FastMCP})." - ) + logger.error(f"The server object {object_name} is of type " f"{type(server_object)} (expecting {FastMCP}).") if isinstance(server_object, LowLevelServer): logger.warning( - "Note that only FastMCP server is supported. Low level " - "Server class is not yet supported." + "Note that only FastMCP server is supported. Low level " "Server class is not yet supported." ) return False return True @@ -172,10 +164,7 @@ def _check_server_object(server_object: Any, object_name: str): for name in ["mcp", "server", "app"]: if hasattr(module, name): if not _check_server_object(getattr(module, name), f"{file}:{name}"): - logger.error( - f"Ignoring object '{file}:{name}' as it's not a valid " - "server object" - ) + logger.error(f"Ignoring object '{file}:{name}' as it's not a valid " "server object") continue return getattr(module, name) @@ -280,8 +269,7 @@ def dev( npx_cmd = _get_npx_command() if not npx_cmd: logger.error( - "npx not found. Please ensure Node.js and npm are properly installed " - "and added to your system PATH." + "npx not found. Please ensure Node.js and npm are properly installed " "and added to your system PATH." ) sys.exit(1) @@ -383,8 +371,7 @@ def install( typer.Option( "--name", "-n", - help="Custom name for the server (defaults to server's name attribute or" - " file name)", + help="Custom name for the server (defaults to server's name attribute or" " file name)", ), ] = None, with_editable: Annotated[ @@ -458,8 +445,7 @@ def install( name = server.name except (ImportError, ModuleNotFoundError) as e: logger.debug( - "Could not import server (likely missing dependencies), using file" - " name", + "Could not import server (likely missing dependencies), using file" " name", extra={"error": str(e)}, ) name = file.stem @@ -477,11 +463,7 @@ def install( if env_file: if dotenv: try: - env_dict |= { - k: v - for k, v in dotenv.dotenv_values(env_file).items() - if v is not None - } + env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None} except Exception as e: logger.error(f"Failed to load .env file: {e}") sys.exit(1) diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 2ec68e56c..2efe05d53 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -24,9 +24,7 @@ async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): logger.error("Error: %s", message) @@ -60,9 +58,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) await run_session(*streams) else: # Use stdio client for commands - server_parameters = StdioServerParameters( - command=command_or_url, args=args, env=env_dict - ) + server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict) async with stdio_client(server_parameters) as streams: await run_session(*streams) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index fc6c96a43..7782022ce 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -17,12 +17,7 @@ import anyio import httpx -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthMetadata, - OAuthToken, -) +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken from mcp.types import LATEST_PROTOCOL_VERSION logger = logging.getLogger(__name__) @@ -100,10 +95,7 @@ def __init__( def _generate_code_verifier(self) -> str: """Generate a cryptographically random code verifier for PKCE.""" - return "".join( - secrets.choice(string.ascii_letters + string.digits + "-._~") - for _ in range(128) - ) + return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) def _generate_code_challenge(self, code_verifier: str) -> str: """Generate a code challenge from a code verifier using SHA256.""" @@ -148,9 +140,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non return None response.raise_for_status() metadata_json = response.json() - logger.debug( - f"OAuth metadata discovered (no MCP header): {metadata_json}" - ) + logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}") return OAuthMetadata.model_validate(metadata_json) except Exception: logger.exception("Failed to discover OAuth metadata") @@ -176,17 +166,11 @@ async def _register_oauth_client( registration_url = urljoin(auth_base_url, "/register") # Handle default scope - if ( - client_metadata.scope is None - and metadata - and metadata.scopes_supported is not None - ): + if client_metadata.scope is None and metadata and metadata.scopes_supported is not None: client_metadata.scope = " ".join(metadata.scopes_supported) # Serialize client metadata - registration_data = client_metadata.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) async with httpx.AsyncClient() as client: try: @@ -213,9 +197,7 @@ async def _register_oauth_client( logger.exception("Registration error") raise - async def async_auth_flow( - self, request: httpx.Request - ) -> AsyncGenerator[httpx.Request, httpx.Response]: + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """ HTTPX auth flow integration. """ @@ -225,9 +207,7 @@ async def async_auth_flow( await self.ensure_token() # Add Bearer token if available if self._current_tokens and self._current_tokens.access_token: - request.headers["Authorization"] = ( - f"Bearer {self._current_tokens.access_token}" - ) + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" response = yield request @@ -305,11 +285,7 @@ async def ensure_token(self) -> None: return # Try refreshing existing token - if ( - self._current_tokens - and self._current_tokens.refresh_token - and await self._refresh_access_token() - ): + if self._current_tokens and self._current_tokens.refresh_token and await self._refresh_access_token(): return # Fall back to full OAuth flow @@ -361,12 +337,8 @@ async def _perform_oauth_flow(self) -> None: auth_code, returned_state = await self.callback_handler() # Validate state parameter for CSRF protection - if returned_state is None or not secrets.compare_digest( - returned_state, self._auth_state - ): - raise Exception( - f"State parameter mismatch: {returned_state} != {self._auth_state}" - ) + if returned_state is None or not secrets.compare_digest(returned_state, self._auth_state): + raise Exception(f"State parameter mismatch: {returned_state} != {self._auth_state}") # Clear state after validation self._auth_state = None @@ -377,9 +349,7 @@ async def _perform_oauth_flow(self) -> None: # Exchange authorization code for tokens await self._exchange_code_for_token(auth_code, client_info) - async def _exchange_code_for_token( - self, auth_code: str, client_info: OAuthClientInformationFull - ) -> None: + async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClientInformationFull) -> None: """Exchange authorization code for access token.""" # Get token endpoint if self._metadata and self._metadata.token_endpoint: @@ -412,17 +382,10 @@ async def _exchange_code_for_token( # Parse OAuth error response try: error_data = response.json() - error_msg = error_data.get( - "error_description", error_data.get("error", "Unknown error") - ) - raise Exception( - f"Token exchange failed: {error_msg} " - f"(HTTP {response.status_code})" - ) + error_msg = error_data.get("error_description", error_data.get("error", "Unknown error")) + raise Exception(f"Token exchange failed: {error_msg} " f"(HTTP {response.status_code})") except Exception: - raise Exception( - f"Token exchange failed: {response.status_code} {response.text}" - ) + raise Exception(f"Token exchange failed: {response.status_code} {response.text}") # Parse token response token_response = OAuthToken.model_validate(response.json()) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3b7fc3fae..02a4e8e01 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -38,16 +38,12 @@ async def __call__( class MessageHandlerFnT(Protocol): async def __call__( self, - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: ... async def _default_message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: await anyio.lowlevel.checkpoint() @@ -77,9 +73,7 @@ async def _default_logging_callback( pass -ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( - types.ClientResult | types.ErrorData -) +ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) class ClientSession( @@ -116,11 +110,7 @@ def __init__( self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: - sampling = ( - types.SamplingCapability() - if self._sampling_callback is not _default_sampling_callback - else None - ) + sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None roots = ( # TODO: Should this be based on whether we # _will_ send notifications, or only whether @@ -149,15 +139,10 @@ async def initialize(self) -> types.InitializeResult: ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: - raise RuntimeError( - "Unsupported protocol version from the server: " - f"{result.protocolVersion}" - ) + raise RuntimeError("Unsupported protocol version from the server: " f"{result.protocolVersion}") await self.send_notification( - types.ClientNotification( - types.InitializedNotification(method="notifications/initialized") - ) + types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) ) return result @@ -207,33 +192,25 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul types.EmptyResult, ) - async def list_resources( - self, cursor: str | None = None - ) -> types.ListResourcesResult: + async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( types.ClientRequest( types.ListResourcesRequest( method="resources/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListResourcesResult, ) - async def list_resource_templates( - self, cursor: str | None = None - ) -> types.ListResourceTemplatesResult: + async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" return await self.send_request( types.ClientRequest( types.ListResourceTemplatesRequest( method="resources/templates/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListResourceTemplatesResult, @@ -305,17 +282,13 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu types.ClientRequest( types.ListPromptsRequest( method="prompts/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListPromptsResult, ) - async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None - ) -> types.GetPromptResult: + async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( @@ -352,9 +325,7 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: types.ClientRequest( types.ListToolsRequest( method="tools/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListToolsResult, @@ -370,9 +341,7 @@ async def send_roots_list_changed(self) -> None: ) ) - async def _received_request( - self, responder: RequestResponder[types.ServerRequest, types.ClientResult] - ) -> None: + async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, meta=responder.request_meta, @@ -395,22 +364,16 @@ async def _received_request( case types.PingRequest(): with responder: - return await responder.respond( - types.ClientResult(root=types.EmptyResult()) - ) + return await responder.respond(types.ClientResult(root=types.EmptyResult())) async def _handle_incoming( self, - req: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: """Handle incoming messages by forwarding to the message handler.""" await self._message_handler(req) - async def _received_notification( - self, notification: types.ServerNotification - ) -> None: + async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" # Process specific notification types match notification.root: diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a77dc7a1e..700b5417f 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -62,9 +62,7 @@ class StreamableHttpParameters(BaseModel): terminate_on_close: bool = True -ServerParameters: TypeAlias = ( - StdioServerParameters | SseServerParameters | StreamableHttpParameters -) +ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters class ClientSessionGroup: @@ -261,9 +259,7 @@ async def _establish_session( ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context( - mcp.ClientSession(read, write) - ) + session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) result = await session.initialize() # Session successfully initialized. @@ -280,9 +276,7 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> None: + async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 2013e4199..1ce6a6da4 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -73,20 +73,16 @@ async def sse_reader( match sse.event: case "endpoint": endpoint_url = urljoin(url, sse.data) - logger.debug( - f"Received endpoint URL: {endpoint_url}" - ) + logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) if ( url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme + or url_parsed.scheme != endpoint_parsed.scheme ): error_msg = ( - "Endpoint origin does not match " - f"connection origin: {endpoint_url}" + "Endpoint origin does not match " f"connection origin: {endpoint_url}" ) logger.error(error_msg) raise ValueError(error_msg) @@ -98,22 +94,16 @@ async def sse_reader( message = types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ) - logger.debug( - f"Received server message: {message}" - ) + logger.debug(f"Received server message: {message}") except Exception as exc: - logger.error( - f"Error parsing server message: {exc}" - ) + logger.error(f"Error parsing server message: {exc}") await read_stream_writer.send(exc) continue session_message = SessionMessage(message) await read_stream_writer.send(session_message) case _: - logger.warning( - f"Unknown SSE event: {sse.event}" - ) + logger.warning(f"Unknown SSE event: {sse.event}") except Exception as exc: logger.error(f"Error in sse_reader: {exc}") await read_stream_writer.send(exc) @@ -124,9 +114,7 @@ async def post_writer(endpoint_url: str): try: async with write_stream_reader: async for session_message in write_stream_reader: - logger.debug( - f"Sending client message: {session_message}" - ) + logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, json=session_message.message.model_dump( @@ -136,19 +124,14 @@ async def post_writer(endpoint_url: str): ), ) response.raise_for_status() - logger.debug( - "Client message sent successfully: " - f"{response.status_code}" - ) + logger.debug("Client message sent successfully: " f"{response.status_code}") except Exception as exc: logger.error(f"Error in post_writer: {exc}") finally: await write_stream.aclose() endpoint_url = await tg.start(sse_reader) - logger.debug( - f"Starting post writer with endpoint URL: {endpoint_url}" - ) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") tg.start_soon(post_writer, endpoint_url) try: diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index fce605633..a75cfd764 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -115,11 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder process = await _create_platform_compatible_process( command=command, args=server.args, - env=( - {**get_default_environment(), **server.env} - if server.env is not None - else get_default_environment() - ), + env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), errlog=errlog, cwd=server.cwd, ) @@ -163,9 +159,7 @@ async def stdin_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, @@ -229,8 +223,6 @@ async def _create_platform_compatible_process( if sys.platform == "win32": process = await create_windows_process(command, args, env, errlog, cwd) else: - process = await anyio.open_process( - [command, *args], env=env, stderr=errlog, cwd=cwd - ) + process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd) return process diff --git a/src/mcp/client/stdio/win32.py b/src/mcp/client/stdio/win32.py index 825a0477d..e4f252dc9 100644 --- a/src/mcp/client/stdio/win32.py +++ b/src/mcp/client/stdio/win32.py @@ -82,9 +82,7 @@ async def create_windows_process( return process except Exception: # Don't raise, let's try to create the process without creation flags - process = await anyio.open_process( - [command, *args], env=env, stderr=errlog, cwd=cwd - ) + process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd) return process diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 2855f606d..cef98d833 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -106,9 +106,7 @@ def __init__( **self.headers, } - def _update_headers_with_session( - self, base_headers: dict[str, str] - ) -> dict[str, str]: + def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: """Update headers with session ID if available.""" headers = base_headers.copy() if self.session_id: @@ -117,17 +115,11 @@ def _update_headers_with_session( def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" - return ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" - ) + return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" - return ( - isinstance(message.root, JSONRPCNotification) - and message.root.method == "notifications/initialized" - ) + return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" def _maybe_extract_session_id_from_response( self, @@ -153,9 +145,7 @@ async def _handle_sse_event( logger.debug(f"SSE message: {message}") # If this is a response and we have original_request_id, replace it - if original_request_id is not None and isinstance( - message.root, JSONRPCResponse | JSONRPCError - ): + if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): message.root.id = original_request_id session_message = SessionMessage(message) @@ -194,9 +184,7 @@ async def handle_get_stream( "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.seconds, read=self.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -225,9 +213,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.seconds, read=ctx.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") @@ -297,9 +283,7 @@ async def _handle_json_response( logger.error(f"Error parsing JSON response: {exc}") await read_stream_writer.send(exc) - async def _handle_sse_response( - self, response: httpx.Response, ctx: RequestContext - ) -> None: + async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: """Handle SSE response from the server.""" try: event_source = EventSource(response) @@ -307,11 +291,7 @@ async def _handle_sse_response( is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, - resumption_callback=( - ctx.metadata.on_resumption_token_update - if ctx.metadata - else None - ), + resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), ) # If the SSE event indicates completion, like returning respose/error # break the loop @@ -454,12 +434,8 @@ async def streamablehttp_client( """ transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) async with anyio.create_task_group() as tg: try: @@ -467,16 +443,12 @@ async def streamablehttp_client( async with httpx_client_factory( headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout.seconds, read=transport.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), auth=transport.auth, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: - tg.start_soon( - transport.handle_get_stream, client, read_stream_writer - ) + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) tg.start_soon( transport.post_writer, diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index ac542fb3f..0a371610b 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -19,10 +19,7 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ - tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], - ], + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None, ]: """ @@ -74,9 +71,7 @@ async def ws_writer(): async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) await ws.send(json.dumps(msg_dict)) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 053c2fd2e..117deea83 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -2,7 +2,4 @@ def stringify_pydantic_error(validation_error: ValidationError) -> str: - return "\n".join( - f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" - for e in validation_error.errors() - ) + return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 8f3768908..f29a3509a 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -7,9 +7,7 @@ from starlette.requests import Request from starlette.responses import RedirectResponse, Response -from mcp.server.auth.errors import ( - stringify_pydantic_error, -) +from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import ( AuthorizationErrorCode, @@ -18,10 +16,7 @@ OAuthAuthorizationServerProvider, construct_redirect_uri, ) -from mcp.shared.auth import ( - InvalidRedirectUriError, - InvalidScopeError, -) +from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError logger = logging.getLogger(__name__) @@ -29,23 +24,16 @@ class AuthorizationRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 client_id: str = Field(..., description="The client ID") - redirect_uri: AnyHttpUrl | None = Field( - None, description="URL to redirect to after authorization" - ) + redirect_uri: AnyHttpUrl | None = Field(None, description="URL to redirect to after authorization") # see OAuthClientMetadata; we only support `code` - response_type: Literal["code"] = Field( - ..., description="Must be 'code' for authorization code flow" - ) + response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") code_challenge: str = Field(..., description="PKCE code challenge") - code_challenge_method: Literal["S256"] = Field( - "S256", description="PKCE code challenge method, must be S256" - ) + code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") state: str | None = Field(None, description="Optional state parameter") scope: str | None = Field( None, - description="Optional scope; if specified, should be " - "a space-separated list of scope strings", + description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) @@ -57,9 +45,7 @@ class AuthorizationErrorResponse(BaseModel): state: str | None = None -def best_effort_extract_string( - key: str, params: None | FormData | QueryParams -) -> str | None: +def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None: if params is None: return None value = params.get(key) @@ -138,9 +124,7 @@ async def error_response( if redirect_uri and client: return RedirectResponse( - url=construct_redirect_uri( - str(redirect_uri), **error_resp.model_dump(exclude_none=True) - ), + url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), status_code=302, headers={"Cache-Control": "no-store"}, ) @@ -172,9 +156,7 @@ async def error_response( if e["loc"] == ("response_type",) and e["type"] == "literal_error": error = "unsupported_response_type" break - return await error_response( - error, stringify_pydantic_error(validation_error) - ) + return await error_response(error, stringify_pydantic_error(validation_error)) # Get client information client = await self.provider.get_client( @@ -229,16 +211,9 @@ async def error_response( ) except AuthorizeError as e: # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 - return await error_response( - error=e.error, - error_description=e.error_description, - ) + return await error_response(error=e.error, error_description=e.error_description) except Exception as validation_error: # Catch-all for unexpected errors - logger.exception( - "Unexpected error in authorization_handler", exc_info=validation_error - ) - return await error_response( - error="server_error", error_description="An unexpected error occurred" - ) + logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) + return await error_response(error="server_error", error_description="An unexpected error occurred") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 2e25c779a..61e403aca 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -10,11 +10,7 @@ from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import ( - OAuthAuthorizationServerProvider, - RegistrationError, - RegistrationErrorCode, -) +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata @@ -60,9 +56,7 @@ async def handle(self, request: Request) -> Response: if client_metadata.scope is None and self.options.default_scopes is not None: client_metadata.scope = " ".join(self.options.default_scopes) - elif ( - client_metadata.scope is not None and self.options.valid_scopes is not None - ): + elif client_metadata.scope is not None and self.options.valid_scopes is not None: requested_scopes = set(client_metadata.scope.split()) valid_scopes = set(self.options.valid_scopes) if not requested_scopes.issubset(valid_scopes): @@ -78,8 +72,7 @@ async def handle(self, request: Request) -> Response: return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", - error_description="grant_types must be authorization_code " - "and refresh_token", + error_description="grant_types must be authorization_code " "and refresh_token", ), status_code=400, ) @@ -122,8 +115,6 @@ async def handle(self, request: Request) -> Response: except RegistrationError as e: # Handle registration errors as defined in RFC 7591 Section 3.2.2 return PydanticJSONResponse( - content=RegistrationErrorResponse( - error=e.error, error_description=e.error_description - ), + content=RegistrationErrorResponse(error=e.error, error_description=e.error_description), status_code=400, ) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 43b4dded9..478ad7a01 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -10,15 +10,8 @@ stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import ( - AuthenticationError, - ClientAuthenticator, -) -from mcp.server.auth.provider import ( - AccessToken, - OAuthAuthorizationServerProvider, - RefreshToken, -) +from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator +from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken class RevocationRequest(BaseModel): diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 94a5c4de3..3b08f0960 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -7,19 +7,10 @@ from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request -from mcp.server.auth.errors import ( - stringify_pydantic_error, -) +from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import ( - AuthenticationError, - ClientAuthenticator, -) -from mcp.server.auth.provider import ( - OAuthAuthorizationServerProvider, - TokenError, - TokenErrorCode, -) +from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode from mcp.shared.auth import OAuthToken @@ -27,9 +18,7 @@ class AuthorizationCodeRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") - redirect_uri: AnyHttpUrl | None = Field( - None, description="Must be the same as redirect URI provided in /authorize" - ) + redirect_uri: AnyHttpUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize") client_id: str # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None @@ -127,8 +116,7 @@ async def handle(self, request: Request): TokenErrorResponse( error="unsupported_grant_type", error_description=( - f"Unsupported grant type (supported grant types are " - f"{client_info.grant_types})" + f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})" ), ) ) @@ -137,9 +125,7 @@ async def handle(self, request: Request): match token_request: case AuthorizationCodeRequest(): - auth_code = await self.provider.load_authorization_code( - client_info, token_request.code - ) + auth_code = await self.provider.load_authorization_code(client_info, token_request.code) if auth_code is None or auth_code.client_id != token_request.client_id: # if code belongs to different client, pretend it doesn't exist return self.response( @@ -169,18 +155,13 @@ async def handle(self, request: Request): return self.response( TokenErrorResponse( error="invalid_request", - error_description=( - "redirect_uri did not match the one " - "used when creating auth code" - ), + error_description=("redirect_uri did not match the one " "used when creating auth code"), ) ) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = ( - base64.urlsafe_b64encode(sha256).decode().rstrip("=") - ) + hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") if hashed_code_verifier != auth_code.code_challenge: # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 @@ -193,9 +174,7 @@ async def handle(self, request: Request): try: # Exchange authorization code for tokens - tokens = await self.provider.exchange_authorization_code( - client_info, auth_code - ) + tokens = await self.provider.exchange_authorization_code(client_info, auth_code) except TokenError as e: return self.response( TokenErrorResponse( @@ -205,13 +184,8 @@ async def handle(self, request: Request): ) case RefreshTokenRequest(): - refresh_token = await self.provider.load_refresh_token( - client_info, token_request.refresh_token - ) - if ( - refresh_token is None - or refresh_token.client_id != token_request.client_id - ): + refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) + if refresh_token is None or refresh_token.client_id != token_request.client_id: # if token belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( @@ -230,29 +204,20 @@ async def handle(self, request: Request): ) # Parse scopes if provided - scopes = ( - token_request.scope.split(" ") - if token_request.scope - else refresh_token.scopes - ) + scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes for scope in scopes: if scope not in refresh_token.scopes: return self.response( TokenErrorResponse( error="invalid_scope", - error_description=( - f"cannot request scope `{scope}` " - "not provided by refresh token" - ), + error_description=(f"cannot request scope `{scope}` " "not provided by refresh token"), ) ) try: # Exchange refresh token for new tokens - tokens = await self.provider.exchange_refresh_token( - client_info, refresh_token, scopes - ) + tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) except TokenError as e: return self.response( TokenErrorResponse( diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1073c07ad..e2116c3bf 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -7,9 +7,7 @@ # Create a contextvar to store the authenticated user # The default is None, indicating no authenticated user is present -auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]( - "auth_context", default=None -) +auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None) def get_access_token() -> AccessToken | None: diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 30b5e2ba6..2fe1342b7 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,11 +1,7 @@ import time from typing import Any -from starlette.authentication import ( - AuthCredentials, - AuthenticationBackend, - SimpleUser, -) +from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection from starlette.types import Receive, Scope, Send @@ -35,11 +31,7 @@ def __init__( async def authenticate(self, conn: HTTPConnection): auth_header = next( - ( - conn.headers.get(key) - for key in conn.headers - if key.lower() == "authorization" - ), + (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), None, ) if not auth_header or not auth_header.lower().startswith("bearer "): @@ -87,10 +79,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: for required_scope in self.required_scopes: # auth_credentials should always be provided; this is just paranoia - if ( - auth_credentials is None - or required_scope not in auth_credentials.scopes - ): + if auth_credentials is None or required_scope not in auth_credentials.scopes: raise HTTPException(status_code=403, detail="Insufficient scope") await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 37f7f5066..d5f473b48 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -30,9 +30,7 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """ self.provider = provider - async def authenticate( - self, client_id: str, client_secret: str | None - ) -> OAuthClientInformationFull: + async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull: # Look up client information client = await self.provider.get_client(client_id) if not client: @@ -47,10 +45,7 @@ async def authenticate( if client.client_secret != client_secret: raise AuthenticationError("Invalid client_secret") - if ( - client.client_secret_expires_at - and client.client_secret_expires_at < int(time.time()) - ): + if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): raise AuthenticationError("Client secret has expired") return client diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index be1ac1dbc..eea34a15b 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -4,10 +4,7 @@ from pydantic import AnyHttpUrl, BaseModel -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthToken, -) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken class AuthorizationParams(BaseModel): @@ -96,9 +93,7 @@ class TokenError(Exception): AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) -class OAuthAuthorizationServerProvider( - Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT] -): +class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ Retrieves client information by client ID. @@ -129,9 +124,7 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ ... - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: """ Called as part of the /authorize endpoint, and returns a URL that the client will be redirected to. @@ -207,9 +200,7 @@ async def exchange_authorization_code( """ ... - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshTokenT | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None: """ Loads a RefreshToken by its token string. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index d588d78ee..dff468ebd 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -31,11 +31,7 @@ def validate_issuer_url(url: AnyHttpUrl): """ # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and not url.host.startswith("127.0.0.1") - ): + if url.scheme != "https" and url.host != "localhost" and not url.host.startswith("127.0.0.1"): raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed @@ -73,9 +69,7 @@ def create_auth_routes( ) -> list[Route]: validate_issuer_url(issuer_url) - client_registration_options = ( - client_registration_options or ClientRegistrationOptions() - ) + client_registration_options = client_registration_options or ClientRegistrationOptions() revocation_options = revocation_options or RevocationOptions() metadata = build_metadata( issuer_url, @@ -177,15 +171,11 @@ def build_metadata( # Add registration endpoint if supported if client_registration_options.enabled: - metadata.registration_endpoint = AnyHttpUrl( - str(issuer_url).rstrip("/") + REGISTRATION_PATH - ) + metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) # Add revocation endpoint if supported if revocation_options.enabled: - metadata.revocation_endpoint = AnyHttpUrl( - str(issuer_url).rstrip("/") + REVOCATION_PATH - ) + metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] return metadata diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index 1086bb77e..7306d91af 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -15,8 +15,7 @@ class RevocationOptions(BaseModel): class AuthSettings(BaseModel): issuer_url: AnyHttpUrl = Field( ..., - description="URL advertised as OAuth issuer; this should be the URL the server " - "is reachable at", + description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at", ) service_documentation_url: AnyHttpUrl | None = None client_registration_options: ClientRegistrationOptions | None = None diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index aa3d1eac9..d9ecc09c3 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -42,13 +42,9 @@ def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): super().__init__(content=content, **kwargs) -message_validator = TypeAdapter[UserMessage | AssistantMessage]( - UserMessage | AssistantMessage -) +message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage) -SyncPromptResult = ( - str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] -) +SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] @@ -56,24 +52,16 @@ class PromptArgument(BaseModel): """An argument that can be passed to a prompt.""" name: str = Field(description="Name of the argument") - description: str | None = Field( - None, description="Description of what the argument does" - ) - required: bool = Field( - default=False, description="Whether the argument is required" - ) + description: str | None = Field(None, description="Description of what the argument does") + required: bool = Field(default=False, description="Whether the argument is required") class Prompt(BaseModel): """A prompt template that can be rendered with parameters.""" name: str = Field(description="Name of the prompt") - description: str | None = Field( - None, description="Description of what the prompt does" - ) - arguments: list[PromptArgument] | None = Field( - None, description="Arguments that can be passed to the prompt" - ) + description: str | None = Field(None, description="Description of what the prompt does") + arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt") fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) @classmethod @@ -154,14 +142,10 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message] content = TextContent(type="text", text=msg) messages.append(UserMessage(content=content)) else: - content = pydantic_core.to_json( - msg, fallback=str, indent=2 - ).decode() + content = pydantic_core.to_json(msg, fallback=str, indent=2).decode() messages.append(Message(role="user", content=content)) except Exception: - raise ValueError( - f"Could not convert prompt result to message: {msg}" - ) + raise ValueError(f"Could not convert prompt result to message: {msg}") return messages except Exception as e: diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 7ccbdef36..6b01d91cd 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -39,9 +39,7 @@ def add_prompt( self._prompts[prompt.name] = prompt return prompt - async def render_prompt( - self, name: str, arguments: dict[str, Any] | None = None - ) -> list[Message]: + async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index b2050e7f8..24bf7d531 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -19,13 +19,9 @@ class Resource(BaseModel, abc.ABC): model_config = ConfigDict(validate_default=True) - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field( - default=..., description="URI of the resource" - ) + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource") name: str | None = Field(description="Name of the resource", default=None) - description: str | None = Field( - description="Description of the resource", default=None - ) + description: str | None = Field(description="Description of the resource", default=None) mime_type: str = Field( default="text/plain", description="MIME type of the resource content", diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index a30b18253..19d7dedd1 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -15,18 +15,12 @@ class ResourceTemplate(BaseModel): """A template for dynamically creating resources.""" - uri_template: str = Field( - description="URI template with parameters (e.g. weather://{city}/current)" - ) + uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)") name: str = Field(description="Name of the resource") description: str | None = Field(description="Description of what the resource does") - mime_type: str = Field( - default="text/plain", description="MIME type of the resource content" - ) + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") fn: Callable[..., Any] = Field(exclude=True) - parameters: dict[str, Any] = Field( - description="JSON schema for function parameters" - ) + parameters: dict[str, Any] = Field(description="JSON schema for function parameters") @classmethod def from_function( diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index d3f10211d..63377f423 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -54,9 +54,7 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - result = ( - await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn() - ) + result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn() if isinstance(result, Resource): return await result.read() elif isinstance(result, bytes): @@ -141,9 +139,7 @@ class HttpResource(Resource): """A resource that reads from an HTTP endpoint.""" url: str = Field(description="URL to fetch content from") - mime_type: str = Field( - default="application/json", description="MIME type of the resource content" - ) + mime_type: str = Field(default="application/json", description="MIME type of the resource content") async def read(self) -> str | bytes: """Read the HTTP content.""" @@ -157,15 +153,9 @@ class DirectoryResource(Resource): """A resource that lists files in a directory.""" path: Path = Field(description="Path to the directory") - recursive: bool = Field( - default=False, description="Whether to list files recursively" - ) - pattern: str | None = Field( - default=None, description="Optional glob pattern to filter files" - ) - mime_type: str = Field( - default="application/json", description="MIME type of the resource content" - ) + recursive: bool = Field(default=False, description="Whether to list files recursively") + pattern: str | None = Field(default=None, description="Optional glob pattern to filter files") + mime_type: str = Field(default="application/json", description="MIME type of the resource content") @pydantic.field_validator("path") @classmethod @@ -184,16 +174,8 @@ def list_files(self) -> list[Path]: try: if self.pattern: - return ( - list(self.path.glob(self.pattern)) - if not self.recursive - else list(self.path.rglob(self.pattern)) - ) - return ( - list(self.path.glob("*")) - if not self.recursive - else list(self.path.rglob("*")) - ) + return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern)) + return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*")) except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index e5b6c3acc..94f5fa082 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -96,9 +96,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # StreamableHTTP settings json_response: bool = False - stateless_http: bool = ( - False # If True, uses true stateless mode (new transport per request) - ) + stateless_http: bool = False # If True, uses true stateless mode (new transport per request) # resource settings warn_on_duplicate_resources: bool = True @@ -114,9 +112,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): description="List of dependencies to install in the server environment", ) - lifespan: ( - Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None - ) = Field(None, description="Lifespan context manager") + lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field( + None, description="Lifespan context manager" + ) auth: AuthSettings | None = None @@ -124,9 +122,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: FastMCP, lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[ - [MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object] -]: +) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]: @asynccontextmanager async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async with lifespan(app) as context: @@ -140,8 +136,7 @@ def __init__( self, name: str | None = None, instructions: str | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, event_store: EventStore | None = None, *, tools: list[Tool] | None = None, @@ -152,31 +147,18 @@ def __init__( self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=( - lifespan_wrapper(self, self.settings.lifespan) - if self.settings.lifespan - else default_lifespan - ), - ) - self._tool_manager = ToolManager( - tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools - ) - self._resource_manager = ResourceManager( - warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources - ) - self._prompt_manager = PromptManager( - warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts + lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) if (self.settings.auth is not None) != (auth_server_provider is not None): # TODO: after we support separate authorization servers (see # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) # we should validate that if auth is enabled, we have either an # auth_server_provider to host our own authorization server, # OR the URL of a 3rd party authorization server. - raise ValueError( - "settings.auth must be specified if and only if auth_server_provider " - "is specified" - ) + raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified") self._auth_server_provider = auth_server_provider self._event_store = event_store self._custom_starlette_routes: list[Route] = [] @@ -339,9 +321,7 @@ def add_tool( description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information """ - self._tool_manager.add_tool( - fn, name=name, description=description, annotations=annotations - ) + self._tool_manager.add_tool(fn, name=name, description=description, annotations=annotations) def tool( self, @@ -378,14 +358,11 @@ async def async_tool(x: int, context: Context) -> str: # Check if user passed function directly instead of calling decorator if callable(name): raise TypeError( - "The @tool decorator was used incorrectly. " - "Did you forget to call it? Use @tool() instead of @tool" + "The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool" ) def decorator(fn: AnyFunction) -> AnyFunction: - self.add_tool( - fn, name=name, description=description, annotations=annotations - ) + self.add_tool(fn, name=name, description=description, annotations=annotations) return fn return decorator @@ -461,8 +438,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: if uri_params != func_params: raise ValueError( - f"Mismatch between URI parameters {uri_params} " - f"and function parameters {func_params}" + f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}" ) # Register as template @@ -495,9 +471,7 @@ def add_prompt(self, prompt: Prompt) -> None: """ self._prompt_manager.add_prompt(prompt) - def prompt( - self, name: str | None = None, description: str | None = None - ) -> Callable[[AnyFunction], AnyFunction]: + def prompt(self, name: str | None = None, description: str | None = None) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a prompt. Args: @@ -664,9 +638,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette: self.settings.mount_path = mount_path # Create normalized endpoint considering the mount path - normalized_message_endpoint = self._normalize_path( - self.settings.mount_path, self.settings.message_path - ) + normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path) # Set up auth context and dependencies @@ -763,9 +735,7 @@ async def sse_endpoint(request: Request) -> Response: routes.extend(self._custom_starlette_routes) # Create Starlette app with routes and middleware - return Starlette( - debug=self.settings.debug, routes=routes, middleware=middleware - ) + return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware) def streamable_http_app(self) -> Starlette: """Return an instance of the StreamableHTTP server app.""" @@ -782,9 +752,7 @@ def streamable_http_app(self) -> Starlette: ) # Create the ASGI handler - async def handle_streamable_http( - scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await self.session_manager.handle_request(scope, receive, send) # Create routes @@ -860,9 +828,7 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt( - self, name: str, arguments: dict[str, Any] | None = None - ) -> GetPromptResult: + async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: """Get a prompt by name with arguments.""" try: messages = await self._prompt_manager.render_prompt(name, arguments) @@ -935,9 +901,7 @@ def my_tool(x: int, ctx: Context) -> str: def __init__( self, *, - request_context: ( - RequestContext[ServerSessionT, LifespanContextT, RequestT] | None - ) = None, + request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -961,9 +925,7 @@ def request_context( raise ValueError("Context is not available outside of a request") return self._request_context - async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None - ) -> None: + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the current operation. Args: @@ -971,11 +933,7 @@ async def report_progress( total: Optional total value e.g. 100 message: Optional message e.g. Starting render... """ - progress_token = ( - self.request_context.meta.progressToken - if self.request_context.meta - else None - ) + progress_token = self.request_context.meta.progressToken if self.request_context.meta else None if progress_token is None: return @@ -996,9 +954,7 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent Returns: The resource content as either text or bytes """ - assert ( - self._fastmcp is not None - ), "Context is not available outside of a request" + assert self._fastmcp is not None, "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) async def log( @@ -1026,11 +982,7 @@ async def log( @property def client_id(self) -> str | None: """Get the client ID if available.""" - return ( - getattr(self.request_context.meta, "client_id", None) - if self.request_context.meta - else None - ) + return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None @property def request_id(self) -> str: diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index f32eb15bd..b25597215 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -25,16 +25,11 @@ class Tool(BaseModel): description: str = Field(description="Description of what the tool does") parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") fn_metadata: FuncMetadata = Field( - description="Metadata about the function including a pydantic model for tool" - " arguments" + description="Metadata about the function including a pydantic model for tool" " arguments" ) is_async: bool = Field(description="Whether the tool is async") - context_kwarg: str | None = Field( - None, description="Name of the kwarg that should receive context" - ) - annotations: ToolAnnotations | None = Field( - None, description="Optional annotations for the tool" - ) + context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") @classmethod def from_function( @@ -93,9 +88,7 @@ async def run( self.fn, self.is_async, arguments, - {self.context_kwarg: context} - if self.context_kwarg is not None - else None, + {self.context_kwarg: context} if self.context_kwarg is not None else None, ) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 153249379..1cd301299 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -50,9 +50,7 @@ def add_tool( annotations: ToolAnnotations | None = None, ) -> Tool: """Add a tool to the server.""" - tool = Tool.from_function( - fn, name=name, description=description, annotations=annotations - ) + tool = Tool.from_function(fn, name=name, description=description, annotations=annotations) existing = self._tools.get(tool.name) if existing: if self.warn_on_duplicate_tools: diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 374391325..9f8d9177a 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -102,9 +102,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: ) -def func_metadata( - func: Callable[..., Any], skip_names: Sequence[str] = () -) -> FuncMetadata: +def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata: """Given a function, return metadata including a pydantic model representing its signature. @@ -131,9 +129,7 @@ def func_metadata( globalns = getattr(func, "__globals__", {}) for param in params.values(): if param.name.startswith("_"): - raise InvalidSignature( - f"Parameter {param.name} of {func.__name__} cannot start with '_'" - ) + raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'") if param.name in skip_names: continue annotation = param.annotation @@ -142,11 +138,7 @@ def func_metadata( if annotation is None: annotation = Annotated[ None, - Field( - default=param.default - if param.default is not inspect.Parameter.empty - else PydanticUndefined - ), + Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined), ] # Untyped field @@ -160,9 +152,7 @@ def func_metadata( field_info = FieldInfo.from_annotated_attribute( _get_typed_annotation(annotation, globalns), - param.default - if param.default is not inspect.Parameter.empty - else PydanticUndefined, + param.default if param.default is not inspect.Parameter.empty else PydanticUndefined, ) dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) continue @@ -177,9 +167,7 @@ def func_metadata( def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: - def try_eval_type( - value: Any, globalns: dict[str, Any], localns: dict[str, Any] - ) -> tuple[Any, bool]: + def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]: try: return eval_type_backport(value, globalns, localns), True except NameError: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b98e3dd1a..4bfc5f620 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -95,9 +95,7 @@ async def main(): RequestT = TypeVar("RequestT", default=Any) # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( - contextvars.ContextVar("request_ctx") -) +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -140,9 +138,7 @@ def __init__( self.version = version self.instructions = instructions self.lifespan = lifespan - self.request_handlers: dict[ - type, Callable[..., Awaitable[types.ServerResult]] - ] = { + self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} @@ -189,9 +185,7 @@ def get_capabilities( # Set prompt capabilities if handler exists if types.ListPromptsRequest in self.request_handlers: - prompts_capability = types.PromptsCapability( - listChanged=notification_options.prompts_changed - ) + prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed) # Set resource capabilities if handler exists if types.ListResourcesRequest in self.request_handlers: @@ -201,9 +195,7 @@ def get_capabilities( # Set tool capabilities if handler exists if types.ListToolsRequest in self.request_handlers: - tools_capability = types.ToolsCapability( - listChanged=notification_options.tools_changed - ) + tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed) # Set logging capabilities if handler exists if types.SetLevelRequest in self.request_handlers: @@ -239,9 +231,7 @@ async def handler(_: Any): def get_prompt(self): def decorator( - func: Callable[ - [str, dict[str, str] | None], Awaitable[types.GetPromptResult] - ], + func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], ): logger.debug("Registering handler for GetPromptRequest") @@ -260,9 +250,7 @@ def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): async def handler(_: Any): resources = await func() - return types.ServerResult( - types.ListResourcesResult(resources=resources) - ) + return types.ServerResult(types.ListResourcesResult(resources=resources)) self.request_handlers[types.ListResourcesRequest] = handler return func @@ -275,9 +263,7 @@ def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): async def handler(_: Any): templates = await func() - return types.ServerResult( - types.ListResourceTemplatesResult(resourceTemplates=templates) - ) + return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates)) self.request_handlers[types.ListResourceTemplatesRequest] = handler return func @@ -286,9 +272,7 @@ async def handler(_: Any): def read_resource(self): def decorator( - func: Callable[ - [AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]] - ], + func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]], ): logger.debug("Registering handler for ReadResourceRequest") @@ -323,8 +307,7 @@ def create_content(data: str | bytes, mime_type: str | None): content = create_content(data, None) case Iterable() as contents: contents_list = [ - create_content(content_item.content, content_item.mime_type) - for content_item in contents + create_content(content_item.content, content_item.mime_type) for content_item in contents ] return types.ServerResult( types.ReadResourceResult( @@ -332,9 +315,7 @@ def create_content(data: str | bytes, mime_type: str | None): ) ) case _: - raise ValueError( - f"Unexpected return type from read_resource: {type(result)}" - ) + raise ValueError(f"Unexpected return type from read_resource: {type(result)}") return types.ServerResult( types.ReadResourceResult( @@ -403,11 +384,7 @@ def call_tool(self): def decorator( func: Callable[ ..., - Awaitable[ - Iterable[ - types.TextContent | types.ImageContent | types.EmbeddedResource - ] - ], + Awaitable[Iterable[types.TextContent | types.ImageContent | types.EmbeddedResource]], ], ): logger.debug("Registering handler for CallToolRequest") @@ -415,9 +392,7 @@ def decorator( async def handler(req: types.CallToolRequest): try: results = await func(req.params.name, (req.params.arguments or {})) - return types.ServerResult( - types.CallToolResult(content=list(results), isError=False) - ) + return types.ServerResult(types.CallToolResult(content=list(results), isError=False)) except Exception as e: return types.ServerResult( types.CallToolResult( @@ -433,9 +408,7 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[ - [str | int, float, float | None, str | None], Awaitable[None] - ], + func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], ): logger.debug("Registering handler for ProgressNotification") @@ -522,9 +495,7 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception, + message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -532,13 +503,9 @@ async def _handle_message( with warnings.catch_warnings(record=True) as w: # TODO(Marcelo): We should be checking if message is Exception here. match message: # type: ignore[reportMatchNotExhaustive] - case ( - RequestResponder(request=types.ClientRequest(root=req)) as responder - ): + case RequestResponder(request=types.ClientRequest(root=req)) as responder: with responder: - await self._handle_request( - message, req, session, lifespan_context, raise_exceptions - ) + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) case types.ClientNotification(root=notify): await self._handle_notification(notify) @@ -562,9 +529,7 @@ async def _handle_request( try: # Extract request context from message metadata request_data = None - if message.message_metadata is not None and isinstance( - message.message_metadata, ServerMessageMetadata - ): + if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): request_data = message.message_metadata.request_context # Set our global state that can be retrieved via diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..e6611b0d4 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -64,9 +64,7 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception + RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception ) @@ -89,22 +87,16 @@ def __init__( init_options: InitializationOptions, stateless: bool = False, ) -> None: - super().__init__( - read_stream, write_stream, types.ClientRequest, types.ClientNotification - ) + super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = ( - InitializationState.Initialized - if stateless - else InitializationState.NotInitialized + InitializationState.Initialized if stateless else InitializationState.NotInitialized ) self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ServerRequestResponder](0) - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_reader.aclose() - ) + self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ + ServerRequestResponder + ](0) + self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -134,10 +126,7 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return False # Check each experimental capability for exp_key, exp_value in capability.experimental.items(): - if ( - exp_key not in client_caps.experimental - or client_caps.experimental[exp_key] != exp_value - ): + if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False return True @@ -146,9 +135,7 @@ async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() - async def _received_request( - self, responder: RequestResponder[types.ClientRequest, types.ServerResult] - ): + async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): match responder.request.root: case types.InitializeRequest(params=params): requested_version = params.protocolVersion @@ -172,13 +159,9 @@ async def _received_request( ) case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received request before initialization was complete" - ) + raise RuntimeError("Received request before initialization was complete") - async def _received_notification( - self, notification: types.ClientNotification - ) -> None: + async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: @@ -186,9 +169,7 @@ async def _received_notification( self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received notification before initialization was complete" - ) + raise RuntimeError("Received notification before initialization was complete") async def send_log_message( self, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 192c1290b..52f273968 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,20 +116,14 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): full_message_path_for_client = root_path.rstrip("/") + self._endpoint # This is the URI (path + query) the client will use to POST messages. - client_post_uri_data = ( - f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" - ) + client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, Any] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) async def sse_writer(): logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: - await sse_stream_writer.send( - {"event": "endpoint", "data": client_post_uri_data} - ) + await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) logger.debug(f"Sent endpoint event: {client_post_uri_data}") async for session_message in write_stream_reader: @@ -137,9 +131,7 @@ async def sse_writer(): await sse_stream_writer.send( { "event": "message", - "data": session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), + "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), } ) @@ -151,9 +143,9 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): In this case we close our side of the streams to signal the client that the connection has been closed. """ - await EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - )(scope, receive, send) + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) await read_stream_writer.aclose() await write_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") @@ -164,9 +156,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message( - self, scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index f0bbe5a31..d1618a371 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -76,9 +76,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index a94cc2834..9356a9948 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -82,9 +82,7 @@ class EventStore(ABC): """ @abstractmethod - async def store_event( - self, stream_id: StreamId, message: JSONRPCMessage - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: """ Stores an event for later retrieval. @@ -125,9 +123,7 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( - None - ) + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None _write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None @@ -153,12 +149,8 @@ def __init__( Raises: ValueError: If the session ID contains invalid characters. """ - if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( - mcp_session_id - ): - raise ValueError( - "Session ID must only contain visible ASCII characters (0x21-0x7E)" - ) + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): + raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)") self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled @@ -218,9 +210,7 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) - if response_message - else None, + response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, status_code=status_code, headers=response_headers, ) @@ -233,9 +223,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", - "data": event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), + "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), } # If an event ID was provided, include it @@ -283,42 +271,29 @@ def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: accept_header = request.headers.get("accept", "") accept_types = [media_type.strip() for media_type in accept_header.split(",")] - has_json = any( - media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types - ) - has_sse = any( - media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types - ) + has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) + has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) return has_json, has_sse def _check_content_type(self, request: Request) -> bool: """Check if the request has the correct Content-Type.""" content_type = request.headers.get("content-type", "") - content_type_parts = [ - part.strip() for part in content_type.split(";")[0].split(",") - ] + content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")] return any(part == CONTENT_TYPE_JSON for part in content_type_parts) - async def _handle_post_request( - self, scope: Scope, request: Request, receive: Receive, send: Send - ) -> None: + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: - raise ValueError( - "No read stream writer available. Ensure connect() is called first." - ) + raise ValueError("No read stream writer available. Ensure connect() is called first.") try: # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): response = self._create_error_response( - ( - "Not Acceptable: Client must accept both application/json and " - "text/event-stream" - ), + ("Not Acceptable: Client must accept both application/json and " "text/event-stream"), HTTPStatus.NOT_ACCEPTABLE, ) await response(scope, receive, send) @@ -346,9 +321,7 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_error_response( - f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR - ) + response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) await response(scope, receive, send) return @@ -364,10 +337,7 @@ async def _handle_post_request( return # Check if this is an initialization request - is_initialization_request = ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" - ) + is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" if is_initialization_request: # Check if the server already has an established session @@ -406,9 +376,7 @@ async def _handle_post_request( # Extract the request ID outside the try block for proper scope request_id = str(message.root.id) # Register this stream for the request ID - self._request_streams[request_id] = anyio.create_memory_object_stream[ - EventMessage - ](0) + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: @@ -424,16 +392,12 @@ async def _handle_post_request( # Use similar approach to SSE writer for consistency async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for - if isinstance( - event_message.message.root, JSONRPCResponse | JSONRPCError - ): + if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): response_message = event_message.message break # For notifications and request, keep waiting else: - logger.debug( - f"received: {event_message.message.root.method}" - ) + logger.debug(f"received: {event_message.message.root.method}") # At this point we should have a response if response_message: @@ -442,9 +406,7 @@ async def _handle_post_request( await response(scope, receive, send) else: # This shouldn't happen in normal operation - logger.error( - "No response message received before stream closed" - ) + logger.error("No response message received before stream closed") response = self._create_error_response( "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, @@ -462,9 +424,7 @@ async def _handle_post_request( await self._clean_up_memory_streams(request_id) else: # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, str]](0) - ) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def sse_writer(): # Get the request ID from the incoming request message @@ -495,11 +455,7 @@ async def sse_writer(): "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", "Content-Type": CONTENT_TYPE_SSE, - **( - {MCP_SESSION_ID_HEADER: self.mcp_session_id} - if self.mcp_session_id - else {} - ), + **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), } response = EventSourceResponse( content=sse_stream_reader, @@ -544,9 +500,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: """ writer = self._read_stream_writer if writer is None: - raise ValueError( - "No read stream writer available. Ensure connect() is called first." - ) + raise ValueError("No read stream writer available. Ensure connect() is called first.") # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -585,17 +539,13 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, str] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages - self._request_streams[GET_STREAM_KEY] = ( - anyio.create_memory_object_stream[EventMessage](0) - ) + self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0) standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: @@ -732,9 +682,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True - async def _replay_events( - self, last_event_id: str, request: Request, send: Send - ) -> None: + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """ Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. @@ -754,9 +702,7 @@ async def _replay_events( headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Create SSE stream for replay - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, str] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def replay_sender(): try: @@ -767,15 +713,11 @@ async def send_event(event_message: EventMessage) -> None: await sse_stream_writer.send(event_data) # Replay past events and get the stream ID - stream_id = await event_store.replay_events_after( - last_event_id, send_event - ) + stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: - self._request_streams[stream_id] = ( - anyio.create_memory_object_stream[EventMessage](0) - ) + self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) msg_reader = self._request_streams[stream_id][1] # Forward messages to SSE @@ -829,12 +771,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer @@ -867,35 +805,24 @@ async def message_router(): session_message.metadata, ServerMessageMetadata, ) - and session_message.metadata.related_request_id - is not None + and session_message.metadata.related_request_id is not None ): - target_request_id = str( - session_message.metadata.related_request_id - ) + target_request_id = str(session_message.metadata.related_request_id) - request_stream_id = ( - target_request_id - if target_request_id is not None - else GET_STREAM_KEY - ) + request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY # Store the event if we have an event store, # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None if self._event_store: - event_id = await self._event_store.store_event( - request_stream_id, message - ) + event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") if request_stream_id in self._request_streams: try: # Send both the message and the event ID - await self._request_streams[request_stream_id][0].send( - EventMessage(message, event_id) - ) + await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) except ( anyio.BrokenResourceError, anyio.ClosedResourceError, diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8188c2f3b..3af9829d1 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -165,9 +165,7 @@ async def _handle_stateless_request( ) # Start server in a new task - async def run_stateless_server( - *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED - ): + async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() @@ -204,10 +202,7 @@ async def _handle_stateful_request( request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) # Existing session case - if ( - request_mcp_session_id is not None - and request_mcp_session_id in self._server_instances - ): + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") await transport.handle_request(scope, receive, send) @@ -229,9 +224,7 @@ async def _handle_stateful_request( logger.info(f"Created new transport with session ID: {new_session_id}") # Define the server runner - async def run_server( - *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED - ) -> None: + async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 54a2fdb8c..a74751312 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -93,12 +93,8 @@ async def handle_async_request( initial_response_ready = anyio.Event() # Synchronization for streaming response - asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[ - dict[str, Any] - ](100) - content_send_channel, content_receive_channel = ( - anyio.create_memory_object_stream[bytes](100) - ) + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100) + content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) # ASGI callables. async def receive() -> dict[str, Any]: @@ -124,21 +120,15 @@ async def send(message: dict[str, Any]) -> None: async def run_app() -> None: try: # Cast the receive and send functions to the ASGI types - await self.app( - cast(Scope, scope), cast(Receive, receive), cast(Send, send) - ) + await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send)) except Exception: if self.raise_app_exceptions: raise if not response_started: - await asgi_send_channel.send( - {"type": "http.response.start", "status": 500, "headers": []} - ) + await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []}) - await asgi_send_channel.send( - {"type": "http.response.body", "body": b"", "more_body": False} - ) + await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False}) finally: await asgi_send_channel.aclose() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dc3f2a25..7c0d8789c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -51,9 +51,7 @@ async def ws_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..e5dbff439 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -36,9 +36,7 @@ class OAuthClientMetadata(BaseModel): # token_endpoint_auth_method: this implementation only supports none & # client_secret_post; # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( - "client_secret_post" - ) + token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" # grant_types: this implementation only supports authorization_code & refresh_token grant_types: list[Literal["authorization_code", "refresh_token"]] = [ "authorization_code", @@ -75,17 +73,12 @@ def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs if redirect_uri not in self.redirect_uris: - raise InvalidRedirectUriError( - f"Redirect URI '{redirect_uri}' not registered for client" - ) + raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri elif len(self.redirect_uris) == 1: return self.redirect_uris[0] else: - raise InvalidRedirectUriError( - "redirect_uri must be specified when client " - "has multiple registered URIs" - ) + raise InvalidRedirectUriError("redirect_uri must be specified when client " "has multiple registered URIs") class OAuthClientInformationFull(OAuthClientMetadata): @@ -113,25 +106,17 @@ class OAuthMetadata(BaseModel): scopes_supported: list[str] | None = None response_types_supported: list[Literal["code"]] = ["code"] response_modes_supported: list[Literal["query", "fragment"]] | None = None - grant_types_supported: ( - list[Literal["authorization_code", "refresh_token"]] | None - ) = None - token_endpoint_auth_methods_supported: ( - list[Literal["none", "client_secret_post"]] | None - ) = None + grant_types_supported: list[Literal["authorization_code", "refresh_token"]] | None = None + token_endpoint_auth_methods_supported: list[Literal["none", "client_secret_post"]] | None = None token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None op_policy_uri: AnyHttpUrl | None = None op_tos_uri: AnyHttpUrl | None = None revocation_endpoint: AnyHttpUrl | None = None - revocation_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + revocation_endpoint_auth_methods_supported: list[Literal["client_secret_post"]] | None = None revocation_endpoint_auth_signing_alg_values_supported: None = None introspection_endpoint: AnyHttpUrl | None = None - introspection_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + introspection_endpoint_auth_methods_supported: list[Literal["client_secret_post"]] | None = None introspection_endpoint_auth_signing_alg_values_supported: None = None code_challenge_methods_supported: list[Literal["S256"]] | None = None diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b53f8dd63..f088d3f8b 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,26 +11,15 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ( - ClientSession, - ListRootsFnT, - LoggingFnT, - MessageHandlerFnT, - SamplingFnT, -) +from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.server import Server from mcp.shared.message import SessionMessage -MessageStream = tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], -] +MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] @asynccontextmanager -async def create_client_server_memory_streams() -> ( - AsyncGenerator[tuple[MessageStream, MessageStream], None] -): +async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: """ Creates a pair of bidirectional memory streams for client-server communication. @@ -39,12 +28,8 @@ async def create_client_server_memory_streams() -> ( (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 6b0233714..4b6df23eb 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -20,9 +20,7 @@ class ClientMessageMetadata: """Metadata specific to client messages.""" resumption_token: ResumptionToken | None = None - on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( - None - ) + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None @dataclass diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 856a8d3b6..1ad81a779 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -23,22 +23,8 @@ class Progress(BaseModel): @dataclass -class ProgressContext( - Generic[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ] -): - session: BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ] +class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]): + session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) @@ -54,24 +40,12 @@ async def progress(self, amount: float, message: str | None = None) -> None: @contextmanager def progress( ctx: RequestContext[ - BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], + BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], LifespanContextT, ], total: float | None = None, ) -> Generator[ - ProgressContext[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], + ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], None, ]: if ctx.meta is None or ctx.meta.progressToken is None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4b13709c6..7f813f2e1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -37,9 +37,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) -ReceiveNotificationT = TypeVar( - "ReceiveNotificationT", ClientNotification, ServerNotification -) +ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) RequestId = str | int @@ -47,9 +45,7 @@ class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" - async def __call__( - self, progress: float, total: float | None, message: str | None - ) -> None: ... + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -176,9 +172,7 @@ class BaseSession( messages when entered. """ - _response_streams: dict[ - RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] - ] + _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] @@ -241,9 +235,7 @@ async def send_request( request_id = self._request_id self._request_id = request_id + 1 - response_stream, response_stream_reader = anyio.create_memory_object_stream[ - JSONRPCResponse | JSONRPCError - ](1) + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) self._response_streams[request_id] = response_stream # Set up progress token if progress callback is provided @@ -265,11 +257,7 @@ async def send_request( **request_data, ) - await self._write_stream.send( - SessionMessage( - message=JSONRPCMessage(jsonrpc_request), metadata=metadata - ) - ) + await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) # request read timeout takes precedence over session read timeout timeout = None @@ -321,15 +309,11 @@ async def send_notification( ) session_message = SessionMessage( message=JSONRPCMessage(jsonrpc_notification), - metadata=ServerMessageMetadata(related_request_id=related_request_id) - if related_request_id - else None, + metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) await self._write_stream.send(session_message) - async def _send_response( - self, request_id: RequestId, response: SendResultT | ErrorData - ) -> None: + async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) @@ -338,9 +322,7 @@ async def _send_response( jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", id=request_id, - result=response.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) await self._write_stream.send(session_message) @@ -355,15 +337,11 @@ async def _receive_loop(self) -> None: await self._handle_incoming(message) elif isinstance(message.message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) responder = RequestResponder( request_id=message.message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=validated_request.root.params.meta if validated_request.root.params else None, request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), @@ -379,9 +357,7 @@ async def _receive_loop(self) -> None: elif isinstance(message.message.root, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) # Handle cancellation notifications if isinstance(notification.root, CancelledNotification): @@ -406,8 +382,7 @@ async def _receive_loop(self) -> None: except Exception as e: # For other validation errors, log and continue logging.warning( - f"Failed to validate notification: {e}. " - f"Message was: {message.message.root}" + f"Failed to validate notification: {e}. " f"Message was: {message.message.root}" ) else: # Response or error stream = self._response_streams.pop(message.message.root.id, None) @@ -415,10 +390,7 @@ async def _receive_loop(self) -> None: await stream.send(message.message.root) else: await self._handle_incoming( - RuntimeError( - "Received response with an unknown " - f"request ID: {message}" - ) + RuntimeError("Received response with an unknown " f"request ID: {message}") ) # after the read stream is closed, we need to send errors @@ -429,9 +401,7 @@ async def _receive_loop(self) -> None: await stream.aclose() self._response_streams.clear() - async def _received_request( - self, responder: RequestResponder[ReceiveRequestT, SendResultT] - ) -> None: + async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ Can be overridden by subclasses to handle a request without needing to listen on the message stream. @@ -460,9 +430,7 @@ async def send_progress_notification( async def _handle_incoming( self, - req: RequestResponder[ReceiveRequestT, SendResultT] - | ReceiveNotificationT - | Exception, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" pass diff --git a/src/mcp/types.py b/src/mcp/types.py index 4f5af27b9..a0d1a67a3 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,12 +1,5 @@ from collections.abc import Callable -from typing import ( - Annotated, - Any, - Generic, - Literal, - TypeAlias, - TypeVar, -) +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints @@ -73,9 +66,7 @@ class Meta(BaseModel): RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) -NotificationParamsT = TypeVar( - "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None -) +NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None) MethodT = TypeVar("MethodT", bound=str) @@ -87,9 +78,7 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): model_config = ConfigDict(extra="allow") -class PaginatedRequest( - Request[PaginatedRequestParams | None, MethodT], Generic[MethodT] -): +class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" @@ -191,9 +180,7 @@ class JSONRPCError(BaseModel): model_config = ConfigDict(extra="allow") -class JSONRPCMessage( - RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError] -): +class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): pass @@ -314,9 +301,7 @@ class InitializeResult(Result): """Instructions describing how to use the server and its features.""" -class InitializedNotification( - Notification[NotificationParams | None, Literal["notifications/initialized"]] -): +class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): """ This notification is sent from the client to the server after initialization has finished. @@ -351,7 +336,7 @@ class ProgressNotificationParams(NotificationParams): """ total: float | None = None """ - Message related to progress. This should provide relevant human readable + Message related to progress. This should provide relevant human readable progress information. """ message: str | None = None @@ -359,9 +344,7 @@ class ProgressNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class ProgressNotification( - Notification[ProgressNotificationParams, Literal["notifications/progress"]] -): +class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): """ An out-of-band notification used to inform the receiver of a progress update for a long-running request. @@ -432,9 +415,7 @@ class ListResourcesResult(PaginatedResult): resources: list[Resource] -class ListResourceTemplatesRequest( - PaginatedRequest[Literal["resources/templates/list"]] -): +class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): """Sent from the client to request a list of resource templates the server has.""" method: Literal["resources/templates/list"] @@ -457,9 +438,7 @@ class ReadResourceRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class ReadResourceRequest( - Request[ReadResourceRequestParams, Literal["resources/read"]] -): +class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): """Sent from the client to the server, to read a specific resource URI.""" method: Literal["resources/read"] @@ -500,9 +479,7 @@ class ReadResourceResult(Result): class ResourceListChangedNotification( - Notification[ - NotificationParams | None, Literal["notifications/resources/list_changed"] - ] + Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] ): """ An optional notification from the server to the client, informing it that the list @@ -542,9 +519,7 @@ class UnsubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class UnsubscribeRequest( - Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]] -): +class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): """ Sent from the client to request cancellation of resources/updated notifications from the server. @@ -566,9 +541,7 @@ class ResourceUpdatedNotificationParams(NotificationParams): class ResourceUpdatedNotification( - Notification[ - ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"] - ] + Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] ): """ A notification from the server to the client, informing it that a resource has @@ -696,9 +669,7 @@ class GetPromptResult(Result): class PromptListChangedNotification( - Notification[ - NotificationParams | None, Literal["notifications/prompts/list_changed"] - ] + Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]] ): """ An optional notification from the server to the client, informing it that the list @@ -805,9 +776,7 @@ class CallToolResult(Result): isError: bool = False -class ToolListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]] -): +class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): """ An optional notification from the server to the client, informing it that the list of tools it offers has changed. @@ -817,9 +786,7 @@ class ToolListChangedNotification( params: NotificationParams | None = None -LoggingLevel = Literal[ - "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency" -] +LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] class SetLevelRequestParams(RequestParams): @@ -852,9 +819,7 @@ class LoggingMessageNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class LoggingMessageNotification( - Notification[LoggingMessageNotificationParams, Literal["notifications/message"]] -): +class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): """Notification of a log message passed from server to client.""" method: Literal["notifications/message"] @@ -949,9 +914,7 @@ class CreateMessageRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CreateMessageRequest( - Request[CreateMessageRequestParams, Literal["sampling/createMessage"]] -): +class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): """A request from the server to sample an LLM via the client.""" method: Literal["sampling/createMessage"] @@ -1108,9 +1071,7 @@ class CancelledNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class CancelledNotification( - Notification[CancelledNotificationParams, Literal["notifications/cancelled"]] -): +class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): """ This notification can be sent by either side to indicate that it is canceling a previously-issued request. @@ -1141,12 +1102,7 @@ class ClientRequest( class ClientNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | InitializedNotification - | RootsListChangedNotification - ] + RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] ): pass diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 60ccac743..0c8283903 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -49,8 +49,7 @@ def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest] return [ req.message.root for req in self.client.sent_messages - if isinstance(req.message.root, JSONRPCRequest) - and (method is None or req.message.root.method == method) + if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: @@ -58,13 +57,10 @@ def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest] return [ req.message.root for req in self.server.sent_messages - if isinstance(req.message.root, JSONRPCRequest) - and (method is None or req.message.root.method == method) + if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] - def get_client_notifications( - self, method: str | None = None - ) -> list[JSONRPCNotification]: + def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: """Get client-sent notifications, optionally filtered by method.""" return [ notif.message.root @@ -73,9 +69,7 @@ def get_client_notifications( and (method is None or notif.message.root.method == method) ] - def get_server_notifications( - self, method: str | None = None - ) -> list[JSONRPCNotification]: + def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: """Get server-sent notifications, optionally filtered by method.""" return [ notif.message.root @@ -133,9 +127,7 @@ async def patched_create_streams(): yield (client_read, spy_client_write), (server_read, spy_server_write) # Apply the patch for the duration of the test - with patch( - "mcp.shared.memory.create_client_server_memory_streams", patched_create_streams - ): + with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams): # Return a collection with helper methods def get_spy_collection() -> StreamSpyCollection: assert client_spy is not None, "client_spy was not initialized" diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2edaff946..cb344f17f 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -134,9 +134,7 @@ def test_generate_code_verifier(self, oauth_provider): assert len(verifier) == 128 # Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~") - allowed_chars = set( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" - ) + allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") assert set(verifier) <= allowed_chars # Check uniqueness (generate multiple and ensure they're different) @@ -150,9 +148,7 @@ def test_generate_code_challenge(self, oauth_provider): # Manually calculate expected challenge expected_digest = hashlib.sha256(verifier.encode()).digest() - expected_challenge = ( - base64.urlsafe_b64encode(expected_digest).decode().rstrip("=") - ) + expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=") assert challenge == expected_challenge @@ -164,29 +160,19 @@ def test_generate_code_challenge(self, oauth_provider): def test_get_authorization_base_url(self, oauth_provider): """Test authorization base URL extraction.""" # Test with path - assert ( - oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") - == "https://api.example.com" - ) + assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" # Test with no path - assert ( - oauth_provider._get_authorization_base_url("https://api.example.com") - == "https://api.example.com" - ) + assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com" # Test with port assert ( - oauth_provider._get_authorization_base_url( - "https://api.example.com:8080/path/to/mcp" - ) + oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080" ) @pytest.mark.anyio - async def test_discover_oauth_metadata_success( - self, oauth_provider, oauth_metadata - ): + async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata): """Test successful OAuth metadata discovery.""" metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") @@ -199,23 +185,16 @@ async def test_discover_oauth_metadata_success( mock_response.json.return_value = metadata_response mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None - assert ( - result.authorization_endpoint == oauth_metadata.authorization_endpoint - ) + assert result.authorization_endpoint == oauth_metadata.authorization_endpoint assert result.token_endpoint == oauth_metadata.token_endpoint # Verify correct URL was called mock_client.get.assert_called_once() call_args = mock_client.get.call_args[0] - assert ( - call_args[0] - == "https://api.example.com/.well-known/oauth-authorization-server" - ) + assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server" @pytest.mark.anyio async def test_discover_oauth_metadata_not_found(self, oauth_provider): @@ -228,16 +207,12 @@ async def test_discover_oauth_metadata_not_found(self, oauth_provider): mock_response.status_code = 404 mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is None @pytest.mark.anyio - async def test_discover_oauth_metadata_cors_fallback( - self, oauth_provider, oauth_metadata - ): + async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata): """Test OAuth metadata discovery with CORS fallback.""" metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") @@ -255,17 +230,13 @@ async def test_discover_oauth_metadata_cors_fallback( mock_response_success, # Second call succeeds ] - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None assert mock_client.get.call_count == 2 @pytest.mark.anyio - async def test_register_oauth_client_success( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info): """Test successful OAuth client registration.""" registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") @@ -293,9 +264,7 @@ async def test_register_oauth_client_success( assert call_args[0][0] == str(oauth_metadata.registration_endpoint) @pytest.mark.anyio - async def test_register_oauth_client_fallback_endpoint( - self, oauth_provider, oauth_client_info - ): + async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info): """Test OAuth client registration with fallback endpoint.""" registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") @@ -309,9 +278,7 @@ async def test_register_oauth_client_fallback_endpoint( mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object( - oauth_provider, "_discover_oauth_metadata", return_value=None - ): + with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): result = await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", oauth_provider.client_metadata, @@ -338,9 +305,7 @@ async def test_register_oauth_client_failure(self, oauth_provider): mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object( - oauth_provider, "_discover_oauth_metadata", return_value=None - ): + with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): with pytest.raises(httpx.HTTPStatusError): await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", @@ -402,9 +367,7 @@ async def test_validate_token_scopes_subset(self, oauth_provider, client_metadat await oauth_provider._validate_token_scopes(token) @pytest.mark.anyio - async def test_validate_token_scopes_unauthorized( - self, oauth_provider, client_metadata - ): + async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata): """Test scope validation with unauthorized scopes.""" oauth_provider.client_metadata = client_metadata token = OAuthToken( @@ -432,9 +395,7 @@ async def test_validate_token_scopes_no_requested(self, oauth_provider): await oauth_provider._validate_token_scopes(token) @pytest.mark.anyio - async def test_initialize( - self, oauth_provider, mock_storage, oauth_token, oauth_client_info - ): + async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info): """Test initialization loading from storage.""" mock_storage._tokens = oauth_token mock_storage._client_info = oauth_client_info @@ -445,9 +406,7 @@ async def test_initialize( assert oauth_provider._client_info == oauth_client_info @pytest.mark.anyio - async def test_get_or_register_client_existing( - self, oauth_provider, oauth_client_info - ): + async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info): """Test getting existing client info.""" oauth_provider._client_info = oauth_client_info @@ -456,13 +415,9 @@ async def test_get_or_register_client_existing( assert result == oauth_client_info @pytest.mark.anyio - async def test_get_or_register_client_register_new( - self, oauth_provider, oauth_client_info - ): + async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info): """Test registering new client.""" - with patch.object( - oauth_provider, "_register_oauth_client", return_value=oauth_client_info - ) as mock_register: + with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register: result = await oauth_provider._get_or_register_client() assert result == oauth_client_info @@ -470,9 +425,7 @@ async def test_get_or_register_client_register_new( mock_register.assert_called_once() @pytest.mark.anyio - async def test_exchange_code_for_token_success( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token): """Test successful code exchange for token.""" oauth_provider._code_verifier = "test_verifier" token_response = oauth_token.model_dump(by_alias=True, mode="json") @@ -486,23 +439,14 @@ async def test_exchange_code_for_token_success( mock_response.json.return_value = token_response mock_client.post.return_value = mock_response - with patch.object( - oauth_provider, "_validate_token_scopes" - ) as mock_validate: - await oauth_provider._exchange_code_for_token( - "test_auth_code", oauth_client_info - ) + with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: + await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info) - assert ( - oauth_provider._current_tokens.access_token - == oauth_token.access_token - ) + assert oauth_provider._current_tokens.access_token == oauth_token.access_token mock_validate.assert_called_once() @pytest.mark.anyio - async def test_exchange_code_for_token_failure( - self, oauth_provider, oauth_client_info - ): + async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info): """Test failed code exchange for token.""" oauth_provider._code_verifier = "test_verifier" @@ -516,14 +460,10 @@ async def test_exchange_code_for_token_failure( mock_client.post.return_value = mock_response with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token( - "invalid_auth_code", oauth_client_info - ) + await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) @pytest.mark.anyio - async def test_refresh_access_token_success( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token): """Test successful token refresh.""" oauth_provider._current_tokens = oauth_token oauth_provider._client_info = oauth_client_info @@ -546,16 +486,11 @@ async def test_refresh_access_token_success( mock_response.json.return_value = token_response mock_client.post.return_value = mock_response - with patch.object( - oauth_provider, "_validate_token_scopes" - ) as mock_validate: + with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: result = await oauth_provider._refresh_access_token() assert result is True - assert ( - oauth_provider._current_tokens.access_token - == new_token.access_token - ) + assert oauth_provider._current_tokens.access_token == new_token.access_token mock_validate.assert_called_once() @pytest.mark.anyio @@ -571,9 +506,7 @@ async def test_refresh_access_token_no_refresh_token(self, oauth_provider): assert result is False @pytest.mark.anyio - async def test_refresh_access_token_failure( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token): """Test failed token refresh.""" oauth_provider._current_tokens = oauth_token oauth_provider._client_info = oauth_client_info @@ -590,9 +523,7 @@ async def test_refresh_access_token_failure( assert result is False @pytest.mark.anyio - async def test_perform_oauth_flow_success( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info): """Test successful OAuth flow.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -636,9 +567,7 @@ async def mock_callback_handler() -> tuple[str, str | None]: mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info) @pytest.mark.anyio - async def test_perform_oauth_flow_state_mismatch( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info): """Test OAuth flow with state parameter mismatch.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -674,9 +603,7 @@ async def test_ensure_token_refresh(self, oauth_provider, oauth_token): oauth_provider._current_tokens = oauth_token oauth_provider._token_expiry_time = time.time() - 3600 # Expired - with patch.object( - oauth_provider, "_refresh_access_token", return_value=True - ) as mock_refresh: + with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh: await oauth_provider.ensure_token() mock_refresh.assert_called_once() @@ -703,10 +630,7 @@ async def test_async_auth_flow_add_token(self, oauth_provider, oauth_token): auth_flow = oauth_provider.async_auth_flow(request) updated_request = await auth_flow.__anext__() - assert ( - updated_request.headers["Authorization"] - == f"Bearer {oauth_token.access_token}" - ) + assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" # Send mock response try: @@ -756,9 +680,7 @@ async def test_async_auth_flow_no_token(self, oauth_provider): # No Authorization header should be added if no token assert "Authorization" not in updated_request.headers - def test_scope_priority_client_metadata_first( - self, oauth_provider, oauth_client_info - ): + def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info): """Test that client metadata scope takes priority.""" oauth_provider.client_metadata.scope = "read write" oauth_provider._client_info = oauth_client_info @@ -777,17 +699,12 @@ def test_scope_priority_client_metadata_first( # Apply scope logic from _perform_oauth_flow if oauth_provider.client_metadata.scope: auth_params["scope"] = oauth_provider.client_metadata.scope - elif ( - hasattr(oauth_provider._client_info, "scope") - and oauth_provider._client_info.scope - ): + elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: auth_params["scope"] = oauth_provider._client_info.scope assert auth_params["scope"] == "read write" - def test_scope_priority_no_client_metadata_scope( - self, oauth_provider, oauth_client_info - ): + def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info): """Test that no scope parameter is set when client metadata has no scope.""" oauth_provider.client_metadata.scope = None oauth_provider._client_info = oauth_client_info @@ -831,10 +748,7 @@ async def test_scope_priority_no_scope(self, oauth_provider, oauth_client_info): # Apply scope logic from _perform_oauth_flow if oauth_provider.client_metadata.scope: auth_params["scope"] = oauth_provider.client_metadata.scope - elif ( - hasattr(oauth_provider._client_info, "scope") - and oauth_provider._client_info.scope - ): + elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: auth_params["scope"] = oauth_provider._client_info.scope # No scope should be set @@ -860,9 +774,7 @@ async def mock_redirect_handler(url: str) -> None: oauth_provider.redirect_handler = mock_redirect_handler # Patch secrets.compare_digest to verify it's being called - with patch( - "mcp.client.auth.secrets.compare_digest", return_value=False - ) as mock_compare: + with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare: with pytest.raises(Exception, match="State parameter mismatch"): await oauth_provider._perform_oauth_flow() @@ -870,9 +782,7 @@ async def mock_redirect_handler(url: str) -> None: mock_compare.assert_called_once() @pytest.mark.anyio - async def test_state_parameter_validation_none_state( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info): """Test that None state is handled correctly.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -907,9 +817,7 @@ async def test_token_exchange_error_basic(self, oauth_provider, oauth_client_inf mock_client.post.return_value = mock_response with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token( - "invalid_auth_code", oauth_client_info - ) + await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) @pytest.mark.parametrize( @@ -962,9 +870,7 @@ def test_build_metadata( metadata = build_metadata( issuer_url=AnyHttpUrl(issuer_url), service_documentation_url=AnyHttpUrl(service_documentation_url), - client_registration_options=ClientRegistrationOptions( - enabled=True, valid_scopes=["read", "write", "admin"] - ), + client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]), revocation_options=RevocationOptions(enabled=True), ) diff --git a/tests/client/test_config.py b/tests/client/test_config.py index 69efb4024..f144dcffb 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path): test_args = [command] + args + ["--help"] - result = subprocess.run( - test_args, capture_output=True, text=True, timeout=5, check=False - ) + result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False) assert result.returncode == 0 assert "usage" in result.stdout.lower() diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index a6df7ec7e..f7b031737 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -182,9 +182,7 @@ async def test_template(name: str) -> str: # Test without cursor parameter (omitted) _ = await client_session.list_resource_templates() - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is None @@ -192,9 +190,7 @@ async def test_template(name: str) -> str: # Test with cursor=None _ = await client_session.list_resource_templates(cursor=None) - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is None @@ -202,9 +198,7 @@ async def test_template(name: str) -> str: # Test with cursor as string _ = await client_session.list_resource_templates(cursor="some_cursor") - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is not None assert list_templates_requests[0].params["cursor"] == "some_cursor" @@ -213,9 +207,7 @@ async def test_template(name: str) -> str: # Test with empty string cursor _ = await client_session.list_resource_templates(cursor="") - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is not None assert list_templates_requests[0].params["cursor"] == "" diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index f5b598218..f65490421 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -41,13 +41,9 @@ async def test_list_roots(context: Context, message: str): # type: ignore[repor return True # Test with list_roots callback - async with create_session( - server._mcp_server, list_roots_callback=list_roots_callback - ) as client_session: + async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_list_roots", {"message": "test message"} - ) + result = await client_session.call_tool("test_list_roots", {"message": "test message"}) assert result.isError is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" @@ -55,12 +51,7 @@ async def test_list_roots(context: Context, message: str): # type: ignore[repor # Test without list_roots callback async with create_session(server._mcp_server) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_list_roots", {"message": "test message"} - ) + result = await client_session.call_tool("test_list_roots", {"message": "test message"}) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert ( - result.content[0].text - == "Error executing tool test_list_roots: List roots not supported" - ) + assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 0c9eeb397..f298ee287 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -49,9 +49,7 @@ async def test_tool_with_log( # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index ba586d4a8..a3f6affda 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -21,9 +21,7 @@ async def test_sampling_callback(): callback_return = CreateMessageResult( role="assistant", - content=TextContent( - type="text", text="This is a response from the sampling callback" - ), + content=TextContent(type="text", text="This is a response from the sampling callback"), model="test-model", stopReason="endTurn", ) @@ -37,24 +35,16 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): value = await server.get_context().session.create_message( - messages=[ - SamplingMessage( - role="user", content=TextContent(type="text", text=message) - ) - ], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) assert value == callback_return return True # Test with sampling callback - async with create_session( - server._mcp_server, sampling_callback=sampling_callback - ) as client_session: + async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_sampling", {"message": "Test message for sampling"} - ) + result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.isError is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" @@ -62,12 +52,7 @@ async def test_sampling_tool(message: str): # Test without sampling callback async with create_session(server._mcp_server) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_sampling", {"message": "Test message for sampling"} - ) + result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert ( - result.content[0].text - == "Error executing tool test_sampling: Sampling not supported" - ) + assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 72b4413d2..327d1a9e4 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -28,12 +28,8 @@ @pytest.mark.anyio async def test_client_session_initialize(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) initialized_notification = None @@ -70,9 +66,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -81,16 +75,12 @@ async def mock_server(): jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -124,12 +114,8 @@ async def message_handler( @pytest.mark.anyio async def test_client_session_custom_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) custom_client_info = Implementation(name="test-client", version="1.2.3") received_client_info = None @@ -161,9 +147,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -192,12 +176,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_client_info = None @@ -228,9 +208,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -259,12 +237,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_success(): """Test successful version negotiation with supported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) async def mock_server(): session_message = await client_to_server_receive.receive() @@ -294,9 +268,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -327,12 +299,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_failure(): """Test version negotiation failure with unsupported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) async def mock_server(): session_message = await client_to_server_receive.receive() @@ -359,9 +327,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -388,12 +354,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_default(): """Test that client capabilities are properly set with default callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None @@ -424,9 +386,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -457,12 +417,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_with_custom_callbacks(): """Test that client capabilities are properly set with custom callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None @@ -508,9 +464,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -536,14 +490,8 @@ async def mock_server(): # Assert that capabilities are properly set with custom callbacks assert received_capabilities is not None - assert ( - received_capabilities.sampling is not None - ) # Custom sampling callback provided + assert received_capabilities.sampling is not None # Custom sampling callback provided assert isinstance(received_capabilities.sampling, types.SamplingCapability) - assert ( - received_capabilities.roots is not None - ) # Custom list_roots callback provided + assert received_capabilities.roots is not None # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) - assert ( - received_capabilities.roots.listChanged is True - ) # Should be True for custom callback + assert received_capabilities.roots.listChanged is True # Should be True for custom callback diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 924ef7a06..16a887e00 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -58,14 +58,10 @@ def hook(name, server_info): return f"{(server_info.name)}-{name}" mcp_session_group = ClientSessionGroup(component_name_hook=hook) - mcp_session_group._tools = { - "server1-my_tool": types.Tool(name="my_tool", inputSchema={}) - } + mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})} mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} text_content = types.TextContent(type="text", text="OK") - mock_session.call_tool.return_value = types.CallToolResult( - content=[text_content] - ) + mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) # --- Test Execution --- result = await mcp_session_group.call_tool( @@ -96,16 +92,12 @@ async def test_connect_to_server(self, mock_exit_stack): mock_prompt1 = mock.Mock(spec=types.Prompt) mock_prompt1.name = "prompt_c" mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) - mock_session.list_resources.return_value = mock.AsyncMock( - resources=[mock_resource1] - ) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) # --- Test Execution --- group = ClientSessionGroup(exit_stack=mock_exit_stack) - with mock.patch.object( - group, "_establish_session", return_value=(mock_server_info, mock_session) - ): + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- @@ -141,12 +133,8 @@ def name_hook(name: str, server_info: types.Implementation) -> str: return f"{server_info.name}.{name}" # --- Test Execution --- - group = ClientSessionGroup( - exit_stack=mock_exit_stack, component_name_hook=name_hook - ) - with mock.patch.object( - group, "_establish_session", return_value=(mock_server_info, mock_session) - ): + group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook) + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- @@ -231,9 +219,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Need a dummy session associated with the existing tool mock_session = mock.MagicMock(spec=mcp.ClientSession) group._tool_to_session[existing_tool_name] = mock_session - group._session_exit_stacks[mock_session] = mock.Mock( - spec=contextlib.AsyncExitStack - ) + group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack) # --- Mock New Connection Attempt --- mock_server_info_new = mock.Mock(spec=types.Implementation) @@ -243,9 +229,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Configure the new session to return a tool with the *same name* duplicate_tool = mock.Mock(spec=types.Tool) duplicate_tool.name = existing_tool_name - mock_session_new.list_tools.return_value = mock.AsyncMock( - tools=[duplicate_tool] - ) + mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) # Keep other lists empty for simplicity mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) @@ -266,9 +250,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Verify the duplicate tool was *not* added again (state should be unchanged) assert len(group._tools) == 1 # Should still only have the original - assert ( - group._tools[existing_tool_name] is not duplicate_tool - ) # Ensure it's the original mock + assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock # No patching needed here async def test_disconnect_non_existent_server(self): @@ -292,9 +274,7 @@ async def test_disconnect_non_existent_server(self): "mcp.client.session_group.sse_client", ), # url, headers, timeout, sse_read_timeout ( - StreamableHttpParameters( - url="http://test.com/stream", terminate_on_close=False - ), + StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), "streamablehttp", "mcp.client.session_group.streamablehttp_client", ), # url, headers, timeout, sse_read_timeout, terminate_on_close @@ -306,13 +286,9 @@ async def test_establish_session_parameterized( client_type_name, # Just for clarity or conditional logic if needed patch_target_for_client_func, ): - with mock.patch( - "mcp.client.session_group.mcp.ClientSession" - ) as mock_ClientSession_class: + with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: with mock.patch(patch_target_for_client_func) as mock_specific_client_func: - mock_client_cm_instance = mock.AsyncMock( - name=f"{client_type_name}ClientCM" - ) + mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") @@ -344,9 +320,7 @@ async def test_establish_session_parameterized( # Mock session.initialize() mock_initialize_result = mock.AsyncMock(name="InitializeResult") - mock_initialize_result.serverInfo = types.Implementation( - name="foo", version="1" - ) + mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1") mock_entered_session.initialize.return_value = mock_initialize_result # --- Test Execution --- @@ -364,9 +338,7 @@ async def test_establish_session_parameterized( # --- Assertions --- # 1. Assert the correct specific client function was called if client_type_name == "stdio": - mock_specific_client_func.assert_called_once_with( - server_params_instance - ) + mock_specific_client_func.assert_called_once_with(server_params_instance) elif client_type_name == "sse": mock_specific_client_func.assert_called_once_with( url=server_params_instance.url, @@ -386,9 +358,7 @@ async def test_establish_session_parameterized( mock_client_cm_instance.__aenter__.assert_awaited_once() # 2. Assert ClientSession was called correctly - mock_ClientSession_class.assert_called_once_with( - mock_read_stream, mock_write_stream - ) + mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream) mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once() diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 1c6ffe000..c66a16ab9 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -50,20 +50,14 @@ async def test_stdio_client(): break assert len(read_messages) == 2 - assert read_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - ) - assert read_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - ) + assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) + assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) @pytest.mark.anyio async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" - server_params = StdioServerParameters( - command="python", args=["-c", "non-existent-file.py"] - ) + server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"]) async with stdio_client(server_params) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # The session should raise an error when the connection closes diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 2bc386c96..6dccec84d 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -17,9 +17,7 @@ def dummy_tool_func(): f"""Tool number {i}""" return i - globals()[f"dummy_tool_{i}"] = ( - dummy_tool_func # Keep reference to avoid garbage collection - ) + globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection # Get all tools tools = await mcp.list_tools() @@ -30,6 +28,4 @@ def dummy_tool_func(): # Verify each tool is unique and has the correct name tool_names = [tool.name for tool in tools] expected_names = [f"tool_{i}" for i in range(num_tools)] - assert sorted(tool_names) == sorted( - expected_names - ), "Tool names don't match expected names" + assert sorted(tool_names) == sorted(expected_names), "Tool names don't match expected names" diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 314952303..4bedb15d5 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -24,9 +24,7 @@ def get_user_profile(user_id: str) -> str: # Note: list_resource_templates() returns a decorator that wraps the handler # The handler returns a ServerResult with a ListResourceTemplatesResult inside result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest( - method="resources/templates/list", params=None - ) + types.ListResourceTemplatesRequest(method="resources/templates/list", params=None) ) assert isinstance(result.root, types.ListResourceTemplatesResult) templates = result.root.resourceTemplates diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 3c17cd559..3145f65e8 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -61,9 +61,7 @@ def get_user_profile_missing(user_id: str) -> str: await mcp.read_resource("resource://users/123/posts") # Missing post_id with pytest.raises(ValueError, match="Unknown resource"): - await mcp.read_resource( - "resource://users/123/posts/456/extra" - ) # Extra path component + await mcp.read_resource("resource://users/123/posts/456/extra") # Extra path component @pytest.mark.anyio @@ -110,11 +108,7 @@ def get_user_profile(user_id: str) -> str: # Verify invalid resource URIs raise appropriate errors with pytest.raises(Exception): # Specific exception type may vary - await session.read_resource( - AnyUrl("resource://users/123/posts") - ) # Missing post_id + await session.read_resource(AnyUrl("resource://users/123/posts")) # Missing post_id with pytest.raises(Exception): # Specific exception type may vary - await session.read_resource( - AnyUrl("resource://users/123/invalid") - ) # Invalid template + await session.read_resource(AnyUrl("resource://users/123/invalid")) # Invalid template diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index 1143195e5..a99e5a5c7 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -45,31 +45,19 @@ def get_image_as_bytes() -> bytes: bytes_resource = mapping["test://image_bytes"] # Verify mime types - assert ( - string_resource.mimeType == "image/png" - ), "String resource mime type not respected" - assert ( - bytes_resource.mimeType == "image/png" - ), "Bytes resource mime type not respected" + assert string_resource.mimeType == "image/png", "String resource mime type not respected" + assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected" # Also verify the content can be read correctly string_result = await client.read_resource(AnyUrl("test://image")) assert len(string_result.contents) == 1 - assert ( - getattr(string_result.contents[0], "text") == base64_string - ), "Base64 string mismatch" - assert ( - string_result.contents[0].mimeType == "image/png" - ), "String content mime type not preserved" + assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch" + assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved" bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) assert len(bytes_result.contents) == 1 - assert ( - base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes - ), "Bytes mismatch" - assert ( - bytes_result.contents[0].mimeType == "image/png" - ), "Bytes content mime type not preserved" + assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch" + assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved" async def test_lowlevel_resource_mime_type(): @@ -82,9 +70,7 @@ async def test_lowlevel_resource_mime_type(): # Create test resources with specific mime types test_resources = [ - types.Resource( - uri=AnyUrl("test://image"), name="test image", mimeType="image/png" - ), + types.Resource(uri=AnyUrl("test://image"), name="test image", mimeType="image/png"), types.Resource( uri=AnyUrl("test://image_bytes"), name="test image bytes", @@ -101,9 +87,7 @@ async def handle_read_resource(uri: AnyUrl): if str(uri) == "test://image": return [ReadResourceContents(content=base64_string, mime_type="image/png")] elif str(uri) == "test://image_bytes": - return [ - ReadResourceContents(content=bytes(image_bytes), mime_type="image/png") - ] + return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] raise Exception(f"Resource not found: {uri}") # Test that resources are listed with correct mime type @@ -119,28 +103,16 @@ async def handle_read_resource(uri: AnyUrl): bytes_resource = mapping["test://image_bytes"] # Verify mime types - assert ( - string_resource.mimeType == "image/png" - ), "String resource mime type not respected" - assert ( - bytes_resource.mimeType == "image/png" - ), "Bytes resource mime type not respected" + assert string_resource.mimeType == "image/png", "String resource mime type not respected" + assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected" # Also verify the content can be read correctly string_result = await client.read_resource(AnyUrl("test://image")) assert len(string_result.contents) == 1 - assert ( - getattr(string_result.contents[0], "text") == base64_string - ), "Base64 string mismatch" - assert ( - string_result.contents[0].mimeType == "image/png" - ), "String content mime type not preserved" + assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch" + assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved" bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) assert len(bytes_result.contents) == 1 - assert ( - base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes - ), "Bytes mismatch" - assert ( - bytes_result.contents[0].mimeType == "image/png" - ), "Bytes content mime type not preserved" + assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch" + assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved" diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..eb5f19d64 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,15 +35,7 @@ async def test_progress_token_zero_first_call(): await ctx.report_progress(10, 10) # Complete # Verify progress notifications - assert ( - mock_session.send_progress_notification.call_count == 3 - ), "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None - ) + assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index cf5eb6083..3c63f00b7 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -66,9 +66,7 @@ async def run_server(): ) await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) - response = ( - await server_reader.receive() - ) # Get init response but don't need to check it + response = await server_reader.receive() # Get init response but don't need to check it # Send initialized notification initialized_notification = JSONRPCNotification( @@ -76,14 +74,10 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send( - SessionMessage(JSONRPCMessage(root=initialized_notification)) - ) + await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification))) # Send ping request with custom ID - ping_request = JSONRPCRequest( - id=custom_request_id, method="ping", params={}, jsonrpc="2.0" - ) + ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0") await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) @@ -91,9 +85,7 @@ async def run_server(): response = await server_reader.receive() # Verify response ID matches request ID - assert ( - response.message.root.id == custom_request_id - ), "Response ID should match request ID" + assert response.message.root.id == custom_request_id, "Response ID should match request ID" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index cff8ec543..6a6e410c7 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -47,11 +47,7 @@ async def test_server_base64_encoding_issue(): # Register a resource handler that returns our test data @server.read_resource() async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: - return [ - ReadResourceContents( - content=binary_data, mime_type="application/octet-stream" - ) - ] + return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] # Get the handler directly from the server handler = server.request_handlers[ReadResourceRequest] diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 88e41d66d..53701824e 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -35,9 +35,7 @@ async def test_notification_validation_error(tmp_path: Path): slow_request_complete = anyio.Event() @server.call_tool() - async def slow_tool( - name: str, arg - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | EmbeddedResource]: nonlocal request_count request_count += 1 @@ -74,9 +72,7 @@ async def client(read_stream, write_stream, scope): # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) # - Not too short to avoid flakiness - async with ClientSession( - read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50) - ) as session: + async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session: await session.initialize() # First call should work (fast operation) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index e8c17a4c4..79b813096 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken: class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" - async def test_no_auth_header( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with no Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None - async def test_non_bearer_auth_header( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with non-Bearer Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -139,9 +135,7 @@ async def test_non_bearer_auth_header( result = await backend.authenticate(request) assert result is None - async def test_invalid_token( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with invalid token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -160,9 +154,7 @@ async def test_expired_token( ): """Test authentication with expired token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider( - mock_oauth_provider, "expired_token", expired_access_token - ) + add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) request = Request( { "type": "http", @@ -203,9 +195,7 @@ async def test_token_without_expiry( ): """Test authentication with token that has no expiry.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider( - mock_oauth_provider, "no_expiry_token", no_expiry_access_token - ) + add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) request = Request( { "type": "http", diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 18e9933e7..7846c8adb 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -128,16 +128,12 @@ async def test_registration_error_handling(self, client, oauth_provider): class TestAuthorizeErrorHandling: @pytest.mark.anyio - async def test_authorize_error_handling( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge): # Mock the authorize method to raise an authorize error with unittest.mock.patch.object( oauth_provider, "authorize", - side_effect=AuthorizeError( - error="access_denied", error_description="The user denied the request" - ), + side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"), ): # Register the client client_id = registered_client["client_id"] @@ -169,9 +165,7 @@ async def test_authorize_error_handling( class TestTokenErrorHandling: @pytest.mark.anyio - async def test_token_error_handling_auth_code( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge): # Register the client and get an auth code client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] @@ -224,9 +218,7 @@ async def test_token_error_handling_auth_code( assert data["error_description"] == "The authorization code is invalid" @pytest.mark.anyio - async def test_token_error_handling_refresh_token( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge): # Register the client and get tokens client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index d237e860e..0f7e7132d 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -47,9 +47,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull): self.clients[client_info.client_id] = client_info - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # toy authorize implementation which just immediately generates an authorization # code and completes the redirect code = AuthorizationCode( @@ -63,9 +61,7 @@ async def authorize( ) self.auth_codes[code.code] = code - return construct_redirect_uri( - str(params.redirect_uri), code=code.code, state=params.state - ) + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -102,9 +98,7 @@ async def exchange_authorization_code( refresh_token=refresh_token, ) - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshToken | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: old_access_token = self.refresh_tokens.get(refresh_token) if old_access_token is None: return None @@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider): @pytest.fixture async def test_client(auth_app): - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" - ) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request): def pkce_challenge(): """Create a PKCE challenge with code_verifier and code_challenge.""" code_verifier = "some_random_verifier_string" - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") return {"code_verifier": code_verifier, "code_challenge": code_challenge} @@ -356,17 +344,13 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): metadata = response.json() assert metadata["issuer"] == "https://auth.example.com/" - assert ( - metadata["authorization_endpoint"] == "https://auth.example.com/authorize" - ) + assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register" assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["response_types_supported"] == ["code"] assert metadata["code_challenge_methods_supported"] == ["S256"] - assert metadata["token_endpoint_auth_methods_supported"] == [ - "client_secret_post" - ] + assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", @@ -386,14 +370,10 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): ) error_response = response.json() assert error_response["error"] == "invalid_request" - assert ( - "error_description" in error_response - ) # Contains validation error messages + assert "error_description" in error_response # Contains validation error messages @pytest.mark.anyio - async def test_token_invalid_auth_code( - self, test_client, registered_client, pkce_challenge - ): + async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -413,9 +393,7 @@ async def test_token_invalid_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert ( - "authorization code does not exist" in error_response["error_description"] - ) + assert "authorization code does not exist" in error_response["error_description"] @pytest.mark.anyio async def test_token_expired_auth_code( @@ -458,9 +436,7 @@ async def test_token_expired_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert ( - "authorization code has expired" in error_response["error_description"] - ) + assert "authorization code has expired" in error_response["error_description"] @pytest.mark.anyio @pytest.mark.parametrize( @@ -475,9 +451,7 @@ async def test_token_expired_auth_code( ], indirect=True, ) - async def test_token_redirect_uri_mismatch( - self, test_client, registered_client, auth_code, pkce_challenge - ): + async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -498,9 +472,7 @@ async def test_token_redirect_uri_mismatch( assert "redirect_uri did not match" in error_response["error_description"] @pytest.mark.anyio - async def test_token_code_verifier_mismatch( - self, test_client, registered_client, auth_code - ): + async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -569,9 +541,7 @@ async def test_token_expired_refresh_token( # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) # Mock the time.time() function to return a value 4 hours in the future - with unittest.mock.patch( - "time.time", return_value=current_time + 14400 - ): # 4 hours = 14400 seconds + with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds # Try to use the refresh token which should now be considered expired response = await test_client.post( "/token", @@ -590,9 +560,7 @@ async def test_token_expired_refresh_token( assert "refresh token has expired" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_scope( - self, test_client, registered_client, auth_code, pkce_challenge - ): + async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( @@ -628,9 +596,7 @@ async def test_token_invalid_scope( assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ): + async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -656,9 +622,7 @@ async def test_client_registration( # ) is not None @pytest.mark.anyio - async def test_client_registration_missing_required_fields( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient): """Test client registration with missing required fields.""" # Missing redirect_uris which is a required field client_metadata = { @@ -677,9 +641,7 @@ async def test_client_registration_missing_required_fields( assert error_data["error_description"] == "redirect_uris: Field required" @pytest.mark.anyio - async def test_client_registration_invalid_uri( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient): """Test client registration with invalid URIs.""" # Invalid redirect_uri format client_metadata = { @@ -696,14 +658,11 @@ async def test_client_registration_invalid_uri( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == ( - "redirect_uris.0: Input should be a valid URL, " - "relative URL without a base" + "redirect_uris.0: Input should be a valid URL, " "relative URL without a base" ) @pytest.mark.anyio - async def test_client_registration_empty_redirect_uris( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): """Test client registration with empty redirect_uris array.""" client_metadata = { "redirect_uris": [], # Empty array @@ -719,8 +678,7 @@ async def test_client_registration_empty_redirect_uris( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert ( - error_data["error_description"] - == "redirect_uris: List should have at least 1 item after validation, not 0" + error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" ) @pytest.mark.anyio @@ -875,12 +833,7 @@ async def test_authorization_get( assert response.status_code == 200 # Verify that the token was revoked - assert ( - await mock_oauth_provider.load_access_token( - new_token_response["access_token"] - ) - is None - ) + assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio async def test_revoke_invalid_token(self, test_client, registered_client): @@ -913,9 +866,7 @@ async def test_revoke_with_malformed_token(self, test_client, registered_client) assert "token_type_hint" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration_disallowed_scopes( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient): """Test client registration with scopes that are not allowed.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -955,18 +906,14 @@ async def test_client_registration_default_scopes( assert client_info["scope"] == "read write" # Retrieve the client from the store to verify default scopes - registered_client = await mock_oauth_provider.get_client( - client_info["client_id"] - ) + registered_client = await mock_oauth_provider.get_client(client_info["client_id"]) assert registered_client is not None # Check that default scopes were applied assert registered_client.scope == "read write" @pytest.mark.anyio - async def test_client_registration_invalid_grant_type( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -981,19 +928,14 @@ async def test_client_registration_invalid_grant_type( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert ( - error_data["error_description"] - == "grant_types must be authorization_code and refresh_token" - ) + assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio - async def test_authorize_missing_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge - ): + async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should @@ -1017,9 +959,7 @@ async def test_authorize_missing_client_id( assert "client_id" in response.text.lower() @pytest.mark.anyio - async def test_authorize_invalid_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge - ): + async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with invalid client_id. According to the OAuth2.0 spec, if client_id is invalid, the server should @@ -1202,9 +1142,7 @@ async def test_authorize_missing_response_type( assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_missing_pkce_challenge( - self, test_client: httpx.AsyncClient, registered_client - ): + async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client): """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. @@ -1233,9 +1171,7 @@ async def test_authorize_missing_pkce_challenge( assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_invalid_scope( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge - ): + async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge): """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. diff --git a/tests/server/fastmcp/prompts/test_base.py b/tests/server/fastmcp/prompts/test_base.py index c4af044a6..5b7b50e63 100644 --- a/tests/server/fastmcp/prompts/test_base.py +++ b/tests/server/fastmcp/prompts/test_base.py @@ -18,9 +18,7 @@ def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_async_fn(self): @@ -28,9 +26,7 @@ async def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_with_args(self): @@ -39,11 +35,7 @@ async def fn(name: str, age: int = 30) -> str: prompt = Prompt.from_function(fn) assert await prompt.render(arguments={"name": "World"}) == [ - UserMessage( - content=TextContent( - type="text", text="Hello, World! You're 30 years old." - ) - ) + UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old.")) ] @pytest.mark.anyio @@ -61,21 +53,15 @@ async def fn() -> UserMessage: return UserMessage(content="Hello, world!") prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_returns_assistant_message(self): async def fn() -> AssistantMessage: - return AssistantMessage( - content=TextContent(type="text", text="Hello, world!") - ) + return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - AssistantMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): @@ -156,9 +142,7 @@ async def fn() -> list[Message]: prompt = Prompt.from_function(fn) assert await prompt.render() == [ - UserMessage( - content=TextContent(type="text", text="Please analyze this file:") - ), + UserMessage(content=TextContent(type="text", text="Please analyze this file:")), UserMessage( content=EmbeddedResource( type="resource", @@ -169,9 +153,7 @@ async def fn() -> list[Message]: ), ) ), - AssistantMessage( - content=TextContent(type="text", text="I'll help analyze that file.") - ), + AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")), ] @pytest.mark.anyio diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index c64a4a564..82b234638 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -72,9 +72,7 @@ def fn() -> str: prompt = Prompt.from_function(fn) manager.add_prompt(prompt) messages = await manager.render_prompt("fn") - assert messages == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_render_prompt_with_args(self): @@ -87,9 +85,7 @@ def fn(name: str) -> str: prompt = Prompt.from_function(fn) manager.add_prompt(prompt) messages = await manager.render_prompt("fn", arguments={"name": "World"}) - assert messages == [ - UserMessage(content=TextContent(type="text", text="Hello, World!")) - ] + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio async def test_render_unknown_prompt(self): diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index 36cbca32c..ec3c85d8d 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -100,9 +100,7 @@ async def test_missing_file_error(self, temp_file: Path): with pytest.raises(ValueError, match="Error reading file"): await resource.read() - @pytest.mark.skipif( - os.name == "nt", reason="File permissions behave differently on Windows" - ) + @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") @pytest.mark.anyio async def test_permission_error(self, temp_file: Path): """Test reading a file without permissions.""" diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index b1828ffe9..b13685e88 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -28,9 +28,7 @@ def complex_arguments_fn( # list[str] | str is an interesting case because if it comes in as JSON like # "[\"a\", \"b\"]" then it will be naively parsed as a string. list_str_or_str: list[str] | str, - an_int_annotated_with_field: Annotated[ - int, Field(description="An int with a field") - ], + an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")], an_int_annotated_with_field_and_others: Annotated[ int, str, # Should be ignored, really @@ -42,9 +40,7 @@ def complex_arguments_fn( "123", 456, ], - field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ - int, Field(1) - ], + field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)], unannotated, my_model_a: SomeInputModelA, my_model_a_forward_ref: "SomeInputModelA", @@ -179,9 +175,7 @@ def func_with_str_types(str_or_list: str | list[str]): def test_skip_names(): """Test that skipped parameters are not included in the model""" - def func_with_many_params( - keep_this: int, skip_this: str, also_keep: float, also_skip: bool - ): + def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): return keep_this, skip_this, also_keep, also_skip # Skip some parameters diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 121492bc6..84c59cbda 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -130,11 +130,7 @@ async def sampling_tool(prompt: str, ctx: Context) -> str: # Request sampling from the client result = await ctx.session.create_message( - messages=[ - SamplingMessage( - role="user", content=TextContent(type="text", text=prompt) - ) - ], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], max_tokens=100, temperature=0.7, ) @@ -278,11 +274,7 @@ def echo(message: str) -> str: def run_server(server_port: int) -> None: """Run the server.""" _, app = make_fastmcp_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting server on port {server_port}") server.run() @@ -290,11 +282,7 @@ def run_server(server_port: int) -> None: def run_everything_legacy_sse_http_server(server_port: int) -> None: """Run the comprehensive server with all features.""" _, app = make_everything_fastmcp_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting comprehensive server on port {server_port}") server.run() @@ -302,11 +290,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None: def run_streamable_http_server(server_port: int) -> None: """Run the StreamableHTTP server.""" _, app = make_fastmcp_streamable_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting StreamableHTTP server on port {server_port}") server.run() @@ -314,11 +298,7 @@ def run_streamable_http_server(server_port: int) -> None: def run_everything_server(server_port: int) -> None: """Run the comprehensive StreamableHTTP server with all features.""" _, app = make_everything_fastmcp_streamable_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting comprehensive StreamableHTTP server on port {server_port}") server.run() @@ -326,11 +306,7 @@ def run_everything_server(server_port: int) -> None: def run_stateless_http_server(server_port: int) -> None: """Run the stateless StreamableHTTP server.""" _, app = make_fastmcp_stateless_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting stateless StreamableHTTP server on port {server_port}") server.run() @@ -369,9 +345,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: """Start the StreamableHTTP server in a separate process.""" - proc = multiprocessing.Process( - target=run_streamable_http_server, args=(http_server_port,), daemon=True - ) + proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True) print("Starting StreamableHTTP server process") proc.start() @@ -388,9 +362,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None] time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"StreamableHTTP server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts") yield @@ -427,9 +399,7 @@ def stateless_http_server( time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Stateless server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts") yield @@ -459,9 +429,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: @pytest.mark.anyio -async def test_fastmcp_streamable_http( - streamable_http_server: None, http_server_url: str -) -> None: +async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None: """Test that FastMCP works with StreamableHTTP transport.""" # Connect to the server using StreamableHTTP async with streamablehttp_client(http_server_url + "/mcp") as ( @@ -484,9 +452,7 @@ async def test_fastmcp_streamable_http( @pytest.mark.anyio -async def test_fastmcp_stateless_streamable_http( - stateless_http_server: None, stateless_http_server_url: str -) -> None: +async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None: """Test that FastMCP works with stateless StreamableHTTP transport.""" # Connect to the server using StreamableHTTP async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( @@ -562,9 +528,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Comprehensive server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts") yield @@ -601,10 +565,7 @@ def everything_streamable_http_server( time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Comprehensive StreamableHTTP server failed to start after " - f"{max_attempts} attempts" - ) + raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts") yield @@ -648,9 +609,7 @@ async def handle_generic_notification(self, message) -> None: await self.handle_tool_list_changed(message.root.params) -async def call_all_mcp_features( - session: ClientSession, collector: NotificationCollector -) -> None: +async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None: """ Test all MCP features using the provided session. @@ -680,9 +639,7 @@ async def call_all_mcp_features( # Test progress callback functionality progress_updates = [] - async def progress_callback( - progress: float, total: float | None, message: str | None - ) -> None: + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: """Collect progress updates for testing (async version).""" progress_updates.append((progress, total, message)) print(f"Progress: {progress}/{total} - {message}") @@ -726,19 +683,12 @@ async def progress_callback( # Verify we received log messages from the sampling tool assert len(collector.log_messages) > 0 - assert any( - "Requesting sampling for prompt" in msg.data for msg in collector.log_messages - ) - assert any( - "Received sampling result from model" in msg.data - for msg in collector.log_messages - ) + assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages) + assert any("Received sampling result from model" in msg.data for msg in collector.log_messages) # 4. Test notification tool notification_message = "test_notifications" - notification_result = await session.call_tool( - "notification_tool", {"message": notification_message} - ) + notification_result = await session.call_tool("notification_tool", {"message": notification_message}) assert len(notification_result.content) == 1 assert isinstance(notification_result.content[0], TextContent) assert "Sent notifications and logs" in notification_result.content[0].text @@ -773,36 +723,24 @@ async def progress_callback( # 2. Dynamic resource resource_category = "test" - dynamic_content = await session.read_resource( - AnyUrl(f"resource://dynamic/{resource_category}") - ) + dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}")) assert isinstance(dynamic_content, ReadResourceResult) assert len(dynamic_content.contents) == 1 assert isinstance(dynamic_content.contents[0], TextResourceContents) - assert ( - f"Dynamic resource content for category: {resource_category}" - in dynamic_content.contents[0].text - ) + assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text # 3. Template resource resource_id = "456" - template_content = await session.read_resource( - AnyUrl(f"resource://template/{resource_id}/data") - ) + template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data")) assert isinstance(template_content, ReadResourceResult) assert len(template_content.contents) == 1 assert isinstance(template_content.contents[0], TextResourceContents) - assert ( - f"Template resource data for ID: {resource_id}" - in template_content.contents[0].text - ) + assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text # Test prompts # 1. Simple prompt prompts = await session.list_prompts() - simple_prompt = next( - (p for p in prompts.prompts if p.name == "simple_prompt"), None - ) + simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None) assert simple_prompt is not None prompt_topic = "AI" @@ -812,16 +750,12 @@ async def progress_callback( # The actual message structure depends on the prompt implementation # 2. Complex prompt - complex_prompt = next( - (p for p in prompts.prompts if p.name == "complex_prompt"), None - ) + complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None) assert complex_prompt is not None query = "What is AI?" context = "technical" - complex_result = await session.get_prompt( - "complex_prompt", {"user_query": query, "context": context} - ) + complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context}) assert isinstance(complex_result, GetPromptResult) assert len(complex_result.messages) >= 1 @@ -837,9 +771,7 @@ async def progress_callback( print(f"Received headers: {headers_data}") # Test 6: Call tool that returns full context - context_result = await session.call_tool( - "echo_context", {"custom_request_id": "test-123"} - ) + context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"}) assert len(context_result.content) == 1 assert isinstance(context_result.content[0], TextContent) @@ -871,9 +803,7 @@ async def sampling_callback( @pytest.mark.anyio -async def test_fastmcp_all_features_sse( - everything_server: None, everything_server_url: str -) -> None: +async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None: """Test all MCP features work correctly with SSE transport.""" # Create notification collector diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index b817761ea..62fd4171e 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -58,9 +58,7 @@ async def test_sse_app_with_mount_path(self): """Test SSE app creation with different mount paths.""" # Test with default mount path mcp = FastMCP() - with patch.object( - mcp, "_normalize_path", return_value="/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize: mcp.sse_app() # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/", "/messages/") @@ -68,18 +66,14 @@ async def test_sse_app_with_mount_path(self): # Test with custom mount path in settings mcp = FastMCP() mcp.settings.mount_path = "/custom" - with patch.object( - mcp, "_normalize_path", return_value="/custom/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize: mcp.sse_app() # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/custom", "/messages/") # Test with mount_path parameter mcp = FastMCP() - with patch.object( - mcp, "_normalize_path", return_value="/param/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize: mcp.sse_app(mount_path="/param") # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/param", "/messages/") @@ -102,9 +96,7 @@ async def test_starlette_routes_with_mount_path(self): # Verify path values assert sse_routes[0].path == "/sse", "SSE route path should be /sse" - assert ( - mount_routes[0].path == "/messages" - ), "Mount route path should be /messages" + assert mount_routes[0].path == "/messages", "Mount route path should be /messages" # Test with mount path as parameter mcp = FastMCP() @@ -120,20 +112,14 @@ async def test_starlette_routes_with_mount_path(self): # Verify path values assert sse_routes[0].path == "/sse", "SSE route path should be /sse" - assert ( - mount_routes[0].path == "/messages" - ), "Mount route path should be /messages" + assert mount_routes[0].path == "/messages", "Mount route path should be /messages" @pytest.mark.anyio async def test_non_ascii_description(self): """Test that FastMCP handles non-ASCII characters in descriptions correctly""" mcp = FastMCP() - @mcp.tool( - description=( - "🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉" - ) - ) + @mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉")) def hello_world(name: str = "世界") -> str: return f"¡Hola, {name}! 👋" @@ -186,9 +172,7 @@ def get_data(x: str) -> str: async def test_add_resource_decorator_incorrect_usage(self): mcp = FastMCP() - with pytest.raises( - TypeError, match="The @resource decorator was used incorrectly" - ): + with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): @mcp.resource # Missing parentheses #type: ignore def get_data(x: str) -> str: @@ -369,9 +353,7 @@ async def test_text_resource(self): def get_text(): return "Hello, world!" - resource = FunctionResource( - uri=AnyUrl("resource://test"), name="test", fn=get_text - ) + resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -407,9 +389,7 @@ async def test_file_resource_text(self, tmp_path: Path): text_file = tmp_path / "test.txt" text_file.write_text("Hello from file!") - resource = FileResource( - uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file - ) + resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -436,10 +416,7 @@ async def test_file_resource_binary(self, tmp_path: Path): async with client_session(mcp._mcp_server) as client: result = await client.read_resource(AnyUrl("file://test.bin")) assert isinstance(result.contents[0], BlobResourceContents) - assert ( - result.contents[0].blob - == base64.b64encode(b"Binary file data").decode() - ) + assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() @pytest.mark.anyio async def test_function_resource(self): @@ -528,9 +505,7 @@ def get_data(org: str, repo: str) -> str: return f"Data for {org}/{repo}" async with client_session(mcp._mcp_server) as client: - result = await client.read_resource( - AnyUrl("resource://cursor/fastmcp/data") - ) + result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data")) assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for cursor/fastmcp" diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index b45c7ac38..7954e1729 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -147,9 +147,7 @@ def test_add_lambda(self): def test_add_lambda_with_no_name(self): manager = ToolManager() - with pytest.raises( - ValueError, match="You must provide a name for lambda functions" - ): + with pytest.raises(ValueError, match="You must provide a name for lambda functions"): manager.add_tool(lambda x: x) def test_warn_on_duplicate_tools(self, caplog): @@ -346,9 +344,7 @@ def tool_without_context(x: int) -> str: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None - def tool_with_parametrized_context( - x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT] - ) -> str: + def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str: return str(x) tool = manager.add_tool(tool_with_parametrized_context) diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index e9eff9ed0..2eb3b7ddb 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -10,13 +10,7 @@ from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import ( - ClientResult, - ServerNotification, - ServerRequest, - Tool, - ToolAnnotations, -) +from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations @pytest.mark.anyio @@ -45,18 +39,12 @@ async def list_tools(): ) ] - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) # Message handler for client async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] - | ServerNotification - | Exception, + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 469eef857..91f6ef8c8 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path): @server.read_resource() async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content=b"Hello World", mime_type="application/octet-stream" - ) - ] + return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] # Get the handler directly from the server handler = server.request_handlers[types.ReadResourceRequest] diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 1375df12f..69321f87c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -20,18 +20,12 @@ @pytest.mark.anyio async def test_server_session_initialize(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -54,9 +48,7 @@ async def run_server(): if isinstance(message, Exception): raise message - if isinstance(message, ClientNotification) and isinstance( - message.root, InitializedNotification - ): + if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification): received_initialized = True return @@ -111,12 +103,8 @@ async def list_resources(): @pytest.mark.anyio async def test_server_session_initialize_with_older_protocol_version(): """Test that server accepts and responds with older protocol (2024-11-05).""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) received_initialized = False received_protocol_version = None @@ -137,9 +125,7 @@ async def run_server(): if isinstance(message, Exception): raise message - if isinstance(message, types.ClientNotification) and isinstance( - message.root, InitializedNotification - ): + if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification): received_initialized = True return @@ -157,9 +143,7 @@ async def mock_client(): params=types.InitializeRequestParams( protocolVersion="2024-11-05", capabilities=types.ClientCapabilities(), - clientInfo=types.Implementation( - name="test-client", version="1.0.0" - ), + clientInfo=types.Implementation(name="test-client", version="1.0.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index c546a7167..2d1850b73 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -22,9 +22,10 @@ async def test_stdio_server(): stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") stdin.seek(0) - async with stdio_server( - stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout) - ) as (read_stream, write_stream): + async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): received_messages = [] async with read_stream: async for message in read_stream: @@ -36,12 +37,8 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - ) - assert received_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - ) + assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) + assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) # Test sending responses from the server responses = [ @@ -58,13 +55,7 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [ - JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines - ] + received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines] assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") - ) - assert received_responses[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) - ) + assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) + assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 32782e458..65828b63b 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once(): async with manager.run(): pass - assert ( - "StreamableHTTPSessionManager .run() can only be called once per instance" - in str(excinfo.value) - ) + assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value) @pytest.mark.anyio @@ -51,10 +48,7 @@ async def try_run(): # One should succeed, one should fail assert len(errors) == 1 - assert ( - "StreamableHTTPSessionManager .run() can only be called once per instance" - in str(errors[0]) - ) + assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) @pytest.mark.anyio @@ -76,6 +70,4 @@ async def send(message): with pytest.raises(RuntimeError) as excinfo: await manager.handle_request(scope, receive, send) - assert "Task group is not initialized. Make sure to use run()." in str( - excinfo.value - ) + assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1e0409e14..08bcb2662 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -22,12 +22,8 @@ async def test_bidirectional_progress_notifications(): """Test that both client and server can send progress notifications.""" # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Run a server session so we can send progress updates in tool async def run_server(): @@ -134,9 +130,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: # Client message handler to store progress notifications async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -172,9 +166,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool( - "test_tool", {"_meta": {"progressToken": client_progress_token}} - ) + await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -221,12 +213,8 @@ async def handle_client_message( async def test_progress_context_manager(): """Test client using progress context manager for sending progress notifications.""" # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Track progress updates server_progress_updates = [] @@ -270,9 +258,7 @@ async def run_server(): # Client message handler async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index eb4e004ae..864e0d1b4 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -90,9 +90,7 @@ async def make_request(client_session): ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="slow_tool", arguments={} - ), + params=types.CallToolRequestParams(name="slow_tool", arguments={}), ) ), types.CallToolResult, @@ -103,9 +101,7 @@ async def make_request(client_session): assert "Request cancelled" in str(e) ev_cancelled.set() - async with create_connected_server_and_client_session( - make_server() - ) as client_session: + async with create_connected_server_and_client_session(make_server()) as client_session: async with anyio.create_task_group() as tg: tg.start_soon(make_request, client_session) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 78bbbb235..8a3fdc435 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -58,11 +58,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -86,12 +82,8 @@ def make_server_app() -> Starlette: server = ServerTest() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await server.run( - streams[0], streams[1], server.create_initialization_options() - ) + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() app = Starlette( @@ -106,11 +98,7 @@ async def handle_sse(request: Request) -> Response: def run_server(server_port: int) -> None: app = make_server_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -122,9 +110,7 @@ def run_server(server_port: int) -> None: @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -169,10 +155,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 - assert ( - response.headers["content-type"] - == "text/event-stream; charset=utf-8" - ) + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" line_number = 0 async for line in response.aiter_lines(): @@ -204,9 +187,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() @@ -234,9 +215,7 @@ async def test_sse_client_exception_handling( @pytest.mark.anyio -@pytest.mark.skip( - "this test highlights a possible bug in SSE read timeout exception handling" -) +@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") async def test_sse_client_timeout( initialized_sse_client_session: ClientSession, ) -> None: @@ -258,11 +237,7 @@ async def test_sse_client_timeout( def run_mounted_server(server_port: int) -> None: app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server( - config=uvicorn.Config( - app=main_app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -274,9 +249,7 @@ def run_mounted_server(server_port: int) -> None: @pytest.fixture() def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -306,9 +279,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app( - mounted_server: None, server_url: str -) -> None: +async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: async with sse_client(server_url + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization @@ -370,12 +341,8 @@ def run_context_server(server_port: int) -> None: context_server = RequestContextServer() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await context_server.run( - streams[0], streams[1], context_server.create_initialization_options() - ) + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) return Response() app = Starlette( @@ -385,11 +352,7 @@ async def handle_sse(request: Request) -> Response: ] ) - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting context server on {server_port}") server.run() @@ -397,9 +360,7 @@ async def handle_sse(request: Request) -> Response: @pytest.fixture() def context_server(server_port: int) -> Generator[None, None, None]: """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process( - target=run_context_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) print("starting context server process") proc.start() @@ -416,9 +377,7 @@ def context_server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Context server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Context server failed to start after {max_attempts} attempts") yield @@ -430,9 +389,7 @@ def context_server(server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_request_context_propagation( - context_server: None, server_url: str -) -> None: +async def test_request_context_propagation(context_server: None, server_url: str) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -456,11 +413,7 @@ async def test_request_context_propagation( # Parse the JSON response assert len(tool_result.content) == 1 - headers_data = json.loads( - tool_result.content[0].text - if tool_result.content[0].type == "text" - else "{}" - ) + headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" @@ -485,15 +438,11 @@ async def test_request_context_isolation(context_server: None, server_url: str) await session.initialize() # Call the tool that echoes context - tool_result = await session.call_tool( - "echo_context", {"request_id": f"request-{i}"} - ) + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 context_data = json.loads( - tool_result.content[0].text - if tool_result.content[0].type == "text" - else "{}" + tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" ) contexts.append(context_data) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 5cf346e1a..615e68efc 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -72,9 +72,7 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] self._event_id_counter = 0 - async def store_event( - self, stream_id: StreamId, message: types.JSONRPCMessage - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) @@ -156,9 +154,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated( - uri=AnyUrl("http://test_resource") - ) + await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) return [TextContent(type="text", text=f"Called {name}")] elif name == "long_running_with_checkpoints": @@ -189,9 +185,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: messages=[ types.SamplingMessage( role="user", - content=types.TextContent( - type="text", text="Server needs client sampling" - ), + content=types.TextContent(type="text", text="Server needs client sampling"), ) ], max_tokens=100, @@ -199,11 +193,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: ) # Return the sampling result in the tool response - response = ( - sampling_result.content.text - if sampling_result.content.type == "text" - else None - ) + response = sampling_result.content.text if sampling_result.content.type == "text" else None return [ TextContent( type="text", @@ -214,9 +204,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] -def create_app( - is_json_response_enabled=False, event_store: EventStore | None = None -) -> Starlette: +def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -245,9 +233,7 @@ def create_app( return app -def run_server( - port: int, is_json_response_enabled=False, event_store: EventStore | None = None -) -> None: +def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None: """Run the test server. Args: @@ -300,9 +286,7 @@ def json_server_port() -> int: @pytest.fixture def basic_server(basic_server_port: int) -> Generator[None, None, None]: """Start a basic server.""" - proc = multiprocessing.Process( - target=run_server, kwargs={"port": basic_server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) proc.start() # Wait for server to be running @@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server @pytest.mark.anyio async def test_streamablehttp_client_resource_read(initialized_client_session): """Test client resource read functionality.""" - response = await initialized_client_session.read_resource( - uri=AnyUrl("foobar://test-resource") - ) + response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 assert response.contents[0].uri == AnyUrl("foobar://test-resource") assert response.contents[0].text == "Read test-resource" @@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) async def test_streamablehttp_client_error_handling(initialized_client_session): """Test error handling in client.""" with pytest.raises(McpError) as exc_info: - await initialized_client_session.read_resource( - uri=AnyUrl("unknown://test-error") - ) + await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) assert exc_info.value.error.code == 0 assert "Unknown resource: unknown://test-error" in exc_info.value.error.message @pytest.mark.anyio -async def test_streamablehttp_client_session_persistence( - basic_server, basic_server_url -): +async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url): """Test that session ID persists across requests.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence( @pytest.mark.anyio -async def test_streamablehttp_client_json_response( - json_response_server, json_server_url -): +async def test_streamablehttp_client_json_response(json_response_server, json_server_url): """Test client with JSON response mode.""" async with streamablehttp_client(f"{json_server_url}/mcp") as ( read_stream, @@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): # Define message handler to capture notifications async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): notifications_received.append(message) @@ -894,9 +868,7 @@ async def message_handler( write_stream, _, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Initialize the session - this triggers the GET stream setup result = await session.initialize() assert isinstance(result, InitializeResult) @@ -914,15 +886,11 @@ async def message_handler( assert str(notif.root.params.uri) == "http://test_resource/" resource_update_found = True - assert ( - resource_update_found - ), "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" @pytest.mark.anyio -async def test_streamablehttp_client_session_termination( - basic_server, basic_server_url -): +async def test_streamablehttp_client_session_termination(basic_server, basic_server_url): """Test client session termination functionality.""" captured_session_id = None @@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination( @pytest.mark.anyio -async def test_streamablehttp_client_session_termination_204( - basic_server, basic_server_url, monkeypatch -): +async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. @@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server): tool_started = False async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): captured_notifications.append(message) @@ -1062,9 +1026,7 @@ async def on_resumption_token_update(token: str) -> None: write_stream, get_session_id, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) @@ -1082,9 +1044,7 @@ async def run_tool(): types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="long_running_with_checkpoints", arguments={} - ), + params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), ) ), types.CallToolResult, @@ -1114,9 +1074,7 @@ async def run_tool(): write_stream, _, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Don't initialize - just use the existing session # Resume the tool with the resumption token @@ -1129,9 +1087,7 @@ async def run_tool(): types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="long_running_with_checkpoints", arguments={} - ), + params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), ) ), types.CallToolResult, @@ -1149,14 +1105,11 @@ async def run_tool(): # Should not have the first notification # Check that "Tool started" notification isn't repeated when resuming assert not any( - isinstance(n.root, types.LoggingMessageNotification) - and n.root.params.data == "Tool started" + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" for n in captured_notifications ) # there is no intersection between pre and post notifications - assert not any( - n in captured_notifications_pre for n in captured_notifications - ) + assert not any(n in captured_notifications_pre for n in captured_notifications) @pytest.mark.anyio @@ -1175,11 +1128,7 @@ async def sampling_callback( nonlocal sampling_callback_invoked, captured_message_params sampling_callback_invoked = True captured_message_params = params - message_received = ( - params.messages[0].content.text - if params.messages[0].content.type == "text" - else None - ) + message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None return types.CreateMessageResult( role="assistant", @@ -1212,19 +1161,13 @@ async def sampling_callback( # Verify the tool result contains the expected content assert len(tool_result.content) == 1 assert tool_result.content[0].type == "text" - assert ( - "Response from sampling: Received message from server" - in tool_result.content[0].text - ) + assert "Response from sampling: Received message from server" in tool_result.content[0].text # Verify sampling callback was invoked assert sampling_callback_invoked assert captured_message_params is not None assert len(captured_message_params.messages) == 1 - assert ( - captured_message_params.messages[0].content.text - == "Server needs client sampling" - ) + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation @@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int): @pytest.fixture def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process( - target=run_context_aware_server, args=(basic_server_port,), daemon=True - ) + proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) proc.start() # Wait for server to be running @@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Context-aware server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts") yield @@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation( "X-Trace-Id": "trace-123", } - async with streamablehttp_client( - f"{basic_server_url}/mcp", headers=custom_headers - ) as (read_stream, write_stream, _): + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as ( + read_stream, + write_stream, + _, + ): async with ClientSession(read_stream, write_stream) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation( @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts = [] @@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation( "Authorization": f"Bearer token-{i}", } - async with streamablehttp_client( - f"{basic_server_url}/mcp", headers=headers - ) as (read_stream, write_stream, _): + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: await session.initialize() # Call the tool that echoes context - tool_result = await session.call_tool( - "echo_context", {"request_id": f"request-{i}"} - ) + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1381c8153..5081f1d53 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -54,11 +54,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -81,12 +77,8 @@ def make_server_app() -> Starlette: server = ServerTest() async def handle_ws(websocket): - async with websocket_server( - websocket.scope, websocket.receive, websocket.send - ) as streams: - await server.run( - streams[0], streams[1], server.create_initialization_options() - ) + async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: + await server.run(streams[0], streams[1], server.create_initialization_options()) app = Starlette( routes=[ @@ -99,11 +91,7 @@ async def handle_ws(websocket): def run_server(server_port: int) -> None: app = make_server_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -115,9 +103,7 @@ def run_server(server_port: int) -> None: @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: @@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response( initialized_ws_client_session: ClientSession, ) -> None: """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource( - AnyUrl("foobar://example") - ) + result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example")) assert isinstance(result, ReadResourceResult) assert isinstance(result.contents, list) assert len(result.contents) > 0 @@ -218,9 +200,7 @@ async def test_ws_client_timeout( # Now test that we can still use the session after a timeout with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource( - AnyUrl("foobar://example") - ) + result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example")) assert isinstance(result, ReadResourceResult) assert isinstance(result.contents, list) assert len(result.contents) > 0 diff --git a/tests/test_examples.py b/tests/test_examples.py index b2fff1a91..230e7d394 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -31,9 +31,7 @@ async def test_complex_inputs(): async with client_session(mcp._mcp_server) as client: tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]} - result = await client.call_tool( - "name_shrimp", {"tank": tank, "extra_names": ["charlie"]} - ) + result = await client.call_tool("name_shrimp", {"tank": tank, "extra_names": ["charlie"]}) assert len(result.content) == 3 assert isinstance(result.content[0], TextContent) assert isinstance(result.content[1], TextContent) @@ -86,9 +84,7 @@ async def test_desktop(monkeypatch): def test_docs_examples(example: CodeExample, eval_example: EvalExample): ruff_ignore: list[str] = ["F841", "I001"] - eval_example.set_config( - ruff_ignore=ruff_ignore, target_version="py310", line_length=88 - ) + eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=88) if eval_example.update_examples: # pragma: no cover eval_example.format(example) From bcd2734ef43df48f94f1a64c2de9777cd9bc7885 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 30 May 2025 12:34:15 +0200 Subject: [PATCH 2/2] revert pyright change --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 56a7d385a..9ad50ab58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,10 +86,10 @@ Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" packages = ["src/mcp"] [tool.pyright] -typeCheckingMode = "strict" include = ["src/mcp", "tests", "examples/servers"] venvPath = "." venv = ".venv" +strict = ["src/mcp/**/*.py"] [tool.ruff.lint] select = ["C4", "E", "F", "I", "PERF", "UP"]