Skip to content

Commit

Permalink
Added httpx_client_kwargs parameter to allow customization of httpx.A…
Browse files Browse the repository at this point in the history
…syncClient behaviour
  • Loading branch information
santibreo committed Jul 30, 2024
1 parent 4a17c26 commit 495f870
Showing 1 changed file with 64 additions and 20 deletions.
84 changes: 64 additions & 20 deletions fastapi_sso/sso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,26 @@
logger = logging.getLogger(__name__)


class HttpxClientKwargsType(TypedDict, total=False):
"""Parameters of :class:`httpx.AsyncClient`"""
verify: bool | str
"""SSL certificates (a.k.a CA bundle) used to verify the identity of
requested hosts. Either `True` (default CA bundle), a path to an SSL
certificate file, an `ssl.SSLContext`, or `False` (which will disable
verification)."""
cert: str | tuple[str, str] | tuple[str, str, str]
"""An SSL certificate used by the requested host to authenticate the
client. Either a path to an SSL certificate file, or two-tuple of
(certificate file, key file), or a three-tuple of (certificate file, key
file, password)."""
proxy: str
"""A proxy URL where all the traffic should be routed."""
proxies: str
"""A dictionary mapping HTTP protocols to proxy URLs."""
timeout: int
"""The timeout configuration to use when sending requests."""


class DiscoveryDocument(TypedDict):
"""Discovery document."""

Expand Down Expand Up @@ -291,21 +311,23 @@ async def verify_and_process(
self,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
convert_response: Literal[True] = True,
params: Optional[Dict[str, Any]],
headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
convert_response: Literal[True],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[OpenID]: ...

@overload
async def verify_and_process(
self,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
params: Optional[Dict[str, Any]],
headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
convert_response: Literal[False],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[Dict[str, Any]]: ...

async def verify_and_process(
Expand All @@ -315,7 +337,8 @@ async def verify_and_process(
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
convert_response: Union[Literal[True], Literal[False]] = True,
convert_response: bool = True,
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
Expand All @@ -325,6 +348,7 @@ async def verify_and_process(
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
convert_response (bool): If True, userinfo response is converted to OpenID object.
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
Raises:
SSOLoginError: If the 'code' parameter is not found in the callback request.
Expand All @@ -334,7 +358,7 @@ async def verify_and_process(
Optional[Dict[str, Any]]: The original JSON response from the API.
"""
headers = headers or {}
code = request.query_params.get("code")
code: Optional[str] = request.query_params.get("code")
if code is None:
logger.debug(
"Callback request:\n\tURI: %s\n\tHeaders: %s\n\tQuery params: %s",
Expand All @@ -359,6 +383,7 @@ async def verify_and_process(
redirect_uri=redirect_uri,
pkce_code_verifier=pkce_code_verifier,
convert_response=convert_response,
httpx_client_kwargs=httpx_client_kwargs,
)

def __enter__(self) -> "SSOBase":
Expand Down Expand Up @@ -390,11 +415,12 @@ async def process_login(
code: str,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
convert_response: Literal[True] = True,
params: Optional[Dict[str, Any]],
additional_headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: Literal[True],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[OpenID]: ...

@overload
Expand All @@ -403,13 +429,28 @@ async def process_login(
code: str,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
params: Optional[Dict[str, Any]],
additional_headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: Literal[False],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[Dict[str, Any]]: ...

@overload
async def process_login(
self,
code: str,
request: Request,
*,
params: Optional[Dict[str, Any]],
additional_headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: bool,
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ...

async def process_login(
self,
code: str,
Expand All @@ -419,7 +460,8 @@ async def process_login(
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
convert_response: Union[Literal[True], Literal[False]] = True,
convert_response: bool = True,
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None,
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
It's a lower-level method, typically, you should use `verify_and_process` instead.
Expand All @@ -432,6 +474,7 @@ async def process_login(
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
convert_response (bool): If True, userinfo response is converted to OpenID object.
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
Raises:
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
Expand All @@ -451,6 +494,7 @@ async def process_login(
),
ReusedOauthClientWarning,
)
httpx_client_kwargs = httpx_client_kwargs or {}
params = params or {}
params.update(self._extra_query_params)
additional_headers = additional_headers or {}
Expand Down Expand Up @@ -483,7 +527,7 @@ async def process_login(

auth = httpx.BasicAuth(self.client_id, self.client_secret)

async with httpx.AsyncClient() as session:
async with httpx.AsyncClient(**httpx_client_kwargs) as session:
response = await session.post(token_url, headers=headers, content=body, auth=auth)
content = response.json()
self._refresh_token = content.get("refresh_token")
Expand Down

0 comments on commit 495f870

Please sign in to comment.