diff --git a/fastapi_sso/sso/base.py b/fastapi_sso/sso/base.py index 8a41d70..36080fd 100644 --- a/fastapi_sso/sso/base.py +++ b/fastapi_sso/sso/base.py @@ -5,7 +5,7 @@ import os import warnings from types import TracebackType -from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload +from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload import httpx import pydantic @@ -20,26 +20,6 @@ logger = logging.getLogger(__name__) -class HttpxClientKwargsType(TypedDict, total=False): - """Parameters of :class:`httpx.AsyncClient`""" - verify: bool | str - """SSL certificates (a.k.a CA bundle) used to verify the identity of - requested hosts. Either `True` (default CA bundle), a path to an SSL - certificate file, an `ssl.SSLContext`, or `False` (which will disable - verification).""" - cert: str | tuple[str, str] | tuple[str, str, str] - """An SSL certificate used by the requested host to authenticate the - client. Either a path to an SSL certificate file, or two-tuple of - (certificate file, key file), or a three-tuple of (certificate file, key - file, password).""" - proxy: str - """A proxy URL where all the traffic should be routed.""" - proxies: str - """A dictionary mapping HTTP protocols to proxy URLs.""" - timeout: int - """The timeout configuration to use when sending requests.""" - - class DiscoveryDocument(TypedDict): """Discovery document.""" @@ -97,12 +77,14 @@ def __init__( allow_insecure_http: bool = False, use_state: bool = False, scope: Optional[List[str]] = None, + get_async_client: Optional[Callable[[], httpx.AsyncClient]] = None, ): """Base class (mixin) for all SSO providers.""" self.client_id: str = client_id self.client_secret: str = client_secret self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri self.allow_insecure_http: bool = allow_insecure_http + self.get_async_client: Callable[[], httpx.AsyncClient] = get_async_client or httpx.AsyncClient self._oauth_client: Optional[WebApplicationClient] = None self._generated_state: Optional[str] = None @@ -315,7 +297,6 @@ async def verify_and_process( headers: Optional[Dict[str, Any]], redirect_uri: Optional[str], convert_response: Literal[True], - httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[OpenID]: ... @overload @@ -327,7 +308,6 @@ async def verify_and_process( headers: Optional[Dict[str, Any]], redirect_uri: Optional[str], convert_response: Literal[False], - httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[Dict[str, Any]]: ... async def verify_and_process( @@ -338,7 +318,6 @@ async def verify_and_process( headers: Optional[Dict[str, Any]] = None, redirect_uri: Optional[str] = None, convert_response: bool = True, - httpx_client_kwargs: Optional[HttpxClientKwargsType] = None ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: """Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path. @@ -348,7 +327,6 @@ async def verify_and_process( headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider. redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance. convert_response (bool): If True, userinfo response is converted to OpenID object. - httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`. Raises: SSOLoginError: If the 'code' parameter is not found in the callback request. @@ -383,7 +361,6 @@ async def verify_and_process( redirect_uri=redirect_uri, pkce_code_verifier=pkce_code_verifier, convert_response=convert_response, - httpx_client_kwargs=httpx_client_kwargs, ) def __enter__(self) -> "SSOBase": @@ -420,7 +397,6 @@ async def process_login( redirect_uri: Optional[str], pkce_code_verifier: Optional[str], convert_response: Literal[True], - httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[OpenID]: ... @overload @@ -434,7 +410,6 @@ async def process_login( redirect_uri: Optional[str], pkce_code_verifier: Optional[str], convert_response: Literal[False], - httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[Dict[str, Any]]: ... @overload @@ -448,7 +423,6 @@ async def process_login( redirect_uri: Optional[str], pkce_code_verifier: Optional[str], convert_response: bool, - httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ... async def process_login( @@ -461,7 +435,6 @@ async def process_login( redirect_uri: Optional[str] = None, pkce_code_verifier: Optional[str] = None, convert_response: bool = True, - httpx_client_kwargs: Optional[HttpxClientKwargsType] = None, ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: """Processes login from the callback endpoint to verify the user and request user info endpoint. It's a lower-level method, typically, you should use `verify_and_process` instead. @@ -474,7 +447,6 @@ async def process_login( redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance. pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request. convert_response (bool): If True, userinfo response is converted to OpenID object. - httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`. Raises: ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues. @@ -494,7 +466,6 @@ async def process_login( ), ReusedOauthClientWarning, ) - httpx_client_kwargs = httpx_client_kwargs or {} params = params or {} params.update(self._extra_query_params) additional_headers = additional_headers or {} @@ -527,7 +498,7 @@ async def process_login( auth = httpx.BasicAuth(self.client_id, self.client_secret) - async with httpx.AsyncClient(**httpx_client_kwargs) as session: + async with self.get_async_client() as session: response = await session.post(token_url, headers=headers, content=body, auth=auth) content = response.json() self._refresh_token = content.get("refresh_token") diff --git a/tests/test_providers.py b/tests/test_providers.py index f53cf53..9a26aa9 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -138,8 +138,7 @@ async def test_login_url_scope_additional(self, Provider: Type[SSOBase]): async def test_process_login(self, Provider: Type[SSOBase], monkeypatch: pytest.MonkeyPatch): sso = Provider("client_id", "client_secret") FakeAsyncClient = make_fake_async_client( - returns_post=Response(url="https://localhost", json_content={"access_token": "token"}), - returns_get=Response( + returns_post=Response(url="https://localhost", json_content={"access_token": "token"}), returns_get=Response( url="https://localhost", json_content=AnythingDict( {"token_endpoint": "https://localhost", "userinfo_endpoint": "https://localhost"} @@ -151,7 +150,7 @@ async def fake_openid_from_response(_, __): return OpenID(id="test", email="email@example.com", display_name="Test") with sso: - monkeypatch.setattr("httpx.AsyncClient", FakeAsyncClient) + monkeypatch.setattr(sso, "get_async_client", FakeAsyncClient) monkeypatch.setattr(sso, "openid_from_response", fake_openid_from_response) request = Request(url="https://localhost?code=code&state=unique") await sso.process_login("code", request)