Skip to content

Commit

Permalink
httpx.AsyncClient factory method to customize client
Browse files Browse the repository at this point in the history
Argument `get_httpx_client` incorporated to `SSOBase` to allow
customization of `httpx.AsyncClient` used to call auth provider
  • Loading branch information
santibreo committed Sep 7, 2024
1 parent a85929e commit 29c1045
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 36 deletions.
37 changes: 4 additions & 33 deletions fastapi_sso/sso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import warnings
from types import TracebackType
from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload

import httpx
import pydantic
Expand All @@ -20,26 +20,6 @@
logger = logging.getLogger(__name__)


class HttpxClientKwargsType(TypedDict, total=False):
"""Parameters of :class:`httpx.AsyncClient`"""
verify: bool | str
"""SSL certificates (a.k.a CA bundle) used to verify the identity of
requested hosts. Either `True` (default CA bundle), a path to an SSL
certificate file, an `ssl.SSLContext`, or `False` (which will disable
verification)."""
cert: str | tuple[str, str] | tuple[str, str, str]
"""An SSL certificate used by the requested host to authenticate the
client. Either a path to an SSL certificate file, or two-tuple of
(certificate file, key file), or a three-tuple of (certificate file, key
file, password)."""
proxy: str
"""A proxy URL where all the traffic should be routed."""
proxies: str
"""A dictionary mapping HTTP protocols to proxy URLs."""
timeout: int
"""The timeout configuration to use when sending requests."""


class DiscoveryDocument(TypedDict):
"""Discovery document."""

Expand Down Expand Up @@ -97,12 +77,14 @@ def __init__(
allow_insecure_http: bool = False,
use_state: bool = False,
scope: Optional[List[str]] = None,
get_async_client: Optional[Callable[[], httpx.AsyncClient]] = None,
):
"""Base class (mixin) for all SSO providers."""
self.client_id: str = client_id
self.client_secret: str = client_secret
self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri
self.allow_insecure_http: bool = allow_insecure_http
self.get_async_client: Callable[[], httpx.AsyncClient] = get_async_client or httpx.AsyncClient
self._oauth_client: Optional[WebApplicationClient] = None
self._generated_state: Optional[str] = None

Expand Down Expand Up @@ -315,7 +297,6 @@ async def verify_and_process(
headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
convert_response: Literal[True],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[OpenID]: ...

@overload
Expand All @@ -327,7 +308,6 @@ async def verify_and_process(
headers: Optional[Dict[str, Any]],
redirect_uri: Optional[str],
convert_response: Literal[False],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[Dict[str, Any]]: ...

async def verify_and_process(
Expand All @@ -338,7 +318,6 @@ async def verify_and_process(
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
convert_response: bool = True,
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
Expand All @@ -348,7 +327,6 @@ async def verify_and_process(
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
convert_response (bool): If True, userinfo response is converted to OpenID object.
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
Raises:
SSOLoginError: If the 'code' parameter is not found in the callback request.
Expand Down Expand Up @@ -383,7 +361,6 @@ async def verify_and_process(
redirect_uri=redirect_uri,
pkce_code_verifier=pkce_code_verifier,
convert_response=convert_response,
httpx_client_kwargs=httpx_client_kwargs,
)

def __enter__(self) -> "SSOBase":
Expand Down Expand Up @@ -420,7 +397,6 @@ async def process_login(
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: Literal[True],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[OpenID]: ...

@overload
Expand All @@ -434,7 +410,6 @@ async def process_login(
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: Literal[False],
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Optional[Dict[str, Any]]: ...

@overload
Expand All @@ -448,7 +423,6 @@ async def process_login(
redirect_uri: Optional[str],
pkce_code_verifier: Optional[str],
convert_response: bool,
httpx_client_kwargs: Optional[HttpxClientKwargsType],
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ...

async def process_login(
Expand All @@ -461,7 +435,6 @@ async def process_login(
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
convert_response: bool = True,
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None,
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
It's a lower-level method, typically, you should use `verify_and_process` instead.
Expand All @@ -474,7 +447,6 @@ async def process_login(
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
convert_response (bool): If True, userinfo response is converted to OpenID object.
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
Raises:
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
Expand All @@ -494,7 +466,6 @@ async def process_login(
),
ReusedOauthClientWarning,
)
httpx_client_kwargs = httpx_client_kwargs or {}
params = params or {}
params.update(self._extra_query_params)
additional_headers = additional_headers or {}
Expand Down Expand Up @@ -527,7 +498,7 @@ async def process_login(

auth = httpx.BasicAuth(self.client_id, self.client_secret)

async with httpx.AsyncClient(**httpx_client_kwargs) as session:
async with self.get_async_client() as session:
response = await session.post(token_url, headers=headers, content=body, auth=auth)
content = response.json()
self._refresh_token = content.get("refresh_token")
Expand Down
5 changes: 2 additions & 3 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ async def test_login_url_scope_additional(self, Provider: Type[SSOBase]):
async def test_process_login(self, Provider: Type[SSOBase], monkeypatch: pytest.MonkeyPatch):
sso = Provider("client_id", "client_secret")
FakeAsyncClient = make_fake_async_client(
returns_post=Response(url="https://localhost", json_content={"access_token": "token"}),
returns_get=Response(
returns_post=Response(url="https://localhost", json_content={"access_token": "token"}), returns_get=Response(
url="https://localhost",
json_content=AnythingDict(
{"token_endpoint": "https://localhost", "userinfo_endpoint": "https://localhost"}
Expand All @@ -151,7 +150,7 @@ async def fake_openid_from_response(_, __):
return OpenID(id="test", email="email@example.com", display_name="Test")

with sso:
monkeypatch.setattr("httpx.AsyncClient", FakeAsyncClient)
monkeypatch.setattr(sso, "get_async_client", FakeAsyncClient)
monkeypatch.setattr(sso, "openid_from_response", fake_openid_from_response)
request = Request(url="https://localhost?code=code&state=unique")
await sso.process_login("code", request)
Expand Down

0 comments on commit 29c1045

Please sign in to comment.