diff --git a/fastapi_sso/sso/base.py b/fastapi_sso/sso/base.py index 1d325a7..8a41d70 100644 --- a/fastapi_sso/sso/base.py +++ b/fastapi_sso/sso/base.py @@ -20,6 +20,26 @@ logger = logging.getLogger(__name__) +class HttpxClientKwargsType(TypedDict, total=False): + """Parameters of :class:`httpx.AsyncClient`""" + verify: bool | str + """SSL certificates (a.k.a CA bundle) used to verify the identity of + requested hosts. Either `True` (default CA bundle), a path to an SSL + certificate file, an `ssl.SSLContext`, or `False` (which will disable + verification).""" + cert: str | tuple[str, str] | tuple[str, str, str] + """An SSL certificate used by the requested host to authenticate the + client. Either a path to an SSL certificate file, or two-tuple of + (certificate file, key file), or a three-tuple of (certificate file, key + file, password).""" + proxy: str + """A proxy URL where all the traffic should be routed.""" + proxies: str + """A dictionary mapping HTTP protocols to proxy URLs.""" + timeout: int + """The timeout configuration to use when sending requests.""" + + class DiscoveryDocument(TypedDict): """Discovery document.""" @@ -291,10 +311,11 @@ async def verify_and_process( self, request: Request, *, - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - convert_response: Literal[True] = True, + params: Optional[Dict[str, Any]], + headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + convert_response: Literal[True], + httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[OpenID]: ... @overload @@ -302,10 +323,11 @@ async def verify_and_process( self, request: Request, *, - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, + params: Optional[Dict[str, Any]], + headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], convert_response: Literal[False], + httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[Dict[str, Any]]: ... async def verify_and_process( @@ -315,7 +337,8 @@ async def verify_and_process( params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, Any]] = None, redirect_uri: Optional[str] = None, - convert_response: Union[Literal[True], Literal[False]] = True, + convert_response: bool = True, + httpx_client_kwargs: Optional[HttpxClientKwargsType] = None ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: """Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path. @@ -325,6 +348,7 @@ async def verify_and_process( headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider. redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance. convert_response (bool): If True, userinfo response is converted to OpenID object. + httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`. Raises: SSOLoginError: If the 'code' parameter is not found in the callback request. @@ -334,7 +358,7 @@ async def verify_and_process( Optional[Dict[str, Any]]: The original JSON response from the API. """ headers = headers or {} - code = request.query_params.get("code") + code: Optional[str] = request.query_params.get("code") if code is None: logger.debug( "Callback request:\n\tURI: %s\n\tHeaders: %s\n\tQuery params: %s", @@ -359,6 +383,7 @@ async def verify_and_process( redirect_uri=redirect_uri, pkce_code_verifier=pkce_code_verifier, convert_response=convert_response, + httpx_client_kwargs=httpx_client_kwargs, ) def __enter__(self) -> "SSOBase": @@ -390,11 +415,12 @@ async def process_login( code: str, request: Request, *, - params: Optional[Dict[str, Any]] = None, - additional_headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - pkce_code_verifier: Optional[str] = None, - convert_response: Literal[True] = True, + params: Optional[Dict[str, Any]], + additional_headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + pkce_code_verifier: Optional[str], + convert_response: Literal[True], + httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[OpenID]: ... @overload @@ -403,13 +429,28 @@ async def process_login( code: str, request: Request, *, - params: Optional[Dict[str, Any]] = None, - additional_headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - pkce_code_verifier: Optional[str] = None, + params: Optional[Dict[str, Any]], + additional_headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + pkce_code_verifier: Optional[str], convert_response: Literal[False], + httpx_client_kwargs: Optional[HttpxClientKwargsType], ) -> Optional[Dict[str, Any]]: ... + @overload + async def process_login( + self, + code: str, + request: Request, + *, + params: Optional[Dict[str, Any]], + additional_headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + pkce_code_verifier: Optional[str], + convert_response: bool, + httpx_client_kwargs: Optional[HttpxClientKwargsType], + ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ... + async def process_login( self, code: str, @@ -419,7 +460,8 @@ async def process_login( additional_headers: Optional[Dict[str, Any]] = None, redirect_uri: Optional[str] = None, pkce_code_verifier: Optional[str] = None, - convert_response: Union[Literal[True], Literal[False]] = True, + convert_response: bool = True, + httpx_client_kwargs: Optional[HttpxClientKwargsType] = None, ) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: """Processes login from the callback endpoint to verify the user and request user info endpoint. It's a lower-level method, typically, you should use `verify_and_process` instead. @@ -432,6 +474,7 @@ async def process_login( redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance. pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request. convert_response (bool): If True, userinfo response is converted to OpenID object. + httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`. Raises: ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues. @@ -451,6 +494,7 @@ async def process_login( ), ReusedOauthClientWarning, ) + httpx_client_kwargs = httpx_client_kwargs or {} params = params or {} params.update(self._extra_query_params) additional_headers = additional_headers or {} @@ -483,7 +527,7 @@ async def process_login( auth = httpx.BasicAuth(self.client_id, self.client_secret) - async with httpx.AsyncClient() as session: + async with httpx.AsyncClient(**httpx_client_kwargs) as session: response = await session.post(token_url, headers=headers, content=body, auth=auth) content = response.json() self._refresh_token = content.get("refresh_token")