Skip to content

Commit 37fcac9

Browse files
committed
feat(idp): handle prompt=login
1 parent b432f75 commit 37fcac9

File tree

3 files changed

+84
-4
lines changed

3 files changed

+84
-4
lines changed

allauth/core/internal/httpkit.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,25 @@ def redirect(to):
4848
return shortcuts.redirect(f"/{to}")
4949

5050

51+
def del_query_params(url: str, *params: str) -> str:
52+
parsed_url = urlparse(url)
53+
query_params = parse_qs(parsed_url.query, keep_blank_values=True)
54+
for param in params:
55+
query_params.pop(param, None)
56+
encoded_query = urlencode(query_params, doseq=True)
57+
new_url = urlunparse(
58+
(
59+
parsed_url.scheme,
60+
parsed_url.netloc,
61+
parsed_url.path,
62+
parsed_url.params,
63+
encoded_query,
64+
parsed_url.fragment,
65+
)
66+
)
67+
return new_url
68+
69+
5170
def add_query_params(url: str, params: dict) -> str:
5271
parsed_url = urlparse(url)
5372
query_params = parse_qs(parsed_url.query)

allauth/idp/oidc/tests/test_authorization.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,37 @@ def test_authorization_code_flow_with_pkce(
366366
"scope",
367367
"refresh_token",
368368
}
369+
370+
371+
@pytest.mark.parametrize("client_fixture", ["auth_client", "client"])
372+
@pytest.mark.parametrize(
373+
"prompt,next_prompt", [("login", None), ("login consent", "consent")]
374+
)
375+
def test_redirect_to_login_with_prompt_login(
376+
request, client_fixture, oidc_client, prompt, next_prompt
377+
):
378+
client = request.getfixturevalue(client_fixture)
379+
redirect_uri = oidc_client.get_redirect_uris()[0]
380+
resp = client.get(
381+
reverse("idp:oidc:authorization")
382+
+ "?"
383+
+ urlencode(
384+
{
385+
"client_id": oidc_client.id,
386+
"response_type": "code",
387+
"redirect_uri": redirect_uri,
388+
"prompt": prompt,
389+
}
390+
)
391+
)
392+
assert resp.status_code == HTTPStatus.FOUND
393+
parts = urlparse(resp["location"])
394+
assert parts.path == reverse(
395+
"account_login" if client_fixture == "client" else "account_reauthenticate"
396+
)
397+
params = parse_qs(parts.query)
398+
params = parse_qs(urlparse(params["next"][0]).query)
399+
if next_prompt:
400+
assert params["prompt"][0] == next_prompt
401+
else:
402+
assert "prompt" not in params

allauth/idp/oidc/views.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import List
1+
from typing import List, Optional
22

3+
from django.contrib.auth import REDIRECT_FIELD_NAME
34
from django.contrib.auth.decorators import login_required
45
from django.contrib.sites.shortcuts import get_current_site
56
from django.core.exceptions import PermissionDenied
67
from django.core.signing import BadSignature, Signer
78
from django.http import (
9+
HttpRequest,
10+
HttpResponse,
811
HttpResponseForbidden,
912
HttpResponseRedirect,
1013
JsonResponse,
@@ -23,7 +26,7 @@
2326
from allauth.account import app_settings as account_settings
2427
from allauth.account.internal.decorators import login_not_required
2528
from allauth.core.internal import jwkkit
26-
from allauth.core.internal.httpkit import add_query_params
29+
from allauth.core.internal.httpkit import add_query_params, del_query_params
2730
from allauth.idp.oidc import app_settings
2831
from allauth.idp.oidc.adapter import get_adapter
2932
from allauth.idp.oidc.forms import AuthorizationForm
@@ -124,10 +127,34 @@ def post(self, request, *args, **kwargs):
124127
return self._respond_with_access_denied()
125128
return super().post(request, *args, **kwargs)
126129

127-
def _login_required(self, request):
130+
def _login_required(self, request) -> Optional[HttpResponse]:
131+
prompts = []
132+
prompt = request.GET.get("prompt")
133+
if prompt:
134+
prompts = prompt.split()
135+
if "login" in prompts:
136+
return self._handle_login_prompt(request, prompts)
128137
if request.user.is_authenticated:
129138
return None
130-
return login_required()(None)(request)
139+
return login_required()(None)(request) # type:ignore[misc,type-var]
140+
141+
def _handle_login_prompt(
142+
self, request: HttpRequest, prompts: List[str]
143+
) -> HttpResponse:
144+
prompts.remove("login")
145+
next_url = request.get_full_path()
146+
if prompts:
147+
next_url = add_query_params(next_url, {"prompt": " ".join(prompts)})
148+
else:
149+
next_url = del_query_params(next_url, "prompt")
150+
params = {}
151+
params[REDIRECT_FIELD_NAME] = next_url
152+
path = reverse(
153+
"account_reauthenticate"
154+
if request.user.is_authenticated
155+
else "account_login"
156+
)
157+
return HttpResponseRedirect(add_query_params(path, params))
131158

132159
def _skip_consent(self):
133160
scopes = self._request_info["request"].scopes

0 commit comments

Comments
 (0)