Skip to content

Commit 59e3a2f

Browse files
committed
feat(idp): support prompt=none
1 parent 37fcac9 commit 59e3a2f

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

allauth/idp/oidc/internal/oauthlib/request_validator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,17 @@ def rotate_refresh_token(self, request):
377377
return app_settings.ROTATE_REFRESH_TOKEN
378378

379379
def validate_silent_login(self, request) -> bool:
380+
if context.request.user.is_authenticated:
381+
request.user = context.request.user
382+
return True
380383
return False
381384

382385
def validate_silent_authorization(self, request) -> bool:
383-
return False
386+
granted_scopes = set()
387+
tokens = Token.objects.valid().filter(
388+
user=context.request.user,
389+
type__in=[Token.Type.REFRESH_TOKEN, Token.Type.ACCESS_TOKEN],
390+
)
391+
for token in tokens.iterator():
392+
granted_scopes.update(token.get_scopes())
393+
return set(request.scopes).issubset(granted_scopes)

allauth/idp/oidc/tests/test_authorization.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,43 @@ def test_redirect_to_login_with_prompt_login(
400400
assert params["prompt"][0] == next_prompt
401401
else:
402402
assert "prompt" not in params
403+
404+
405+
@pytest.mark.parametrize(
406+
"client_fixture,scope,error",
407+
[
408+
("auth_client", "openid", None),
409+
("auth_client", "openid profile", "consent_required"),
410+
("client", "openid", "login_required"),
411+
],
412+
)
413+
def test_prompt_none(
414+
request,
415+
client_fixture,
416+
scope,
417+
error,
418+
oidc_client,
419+
user,
420+
access_token_generator,
421+
):
422+
access_token_generator(oidc_client, user, scopes=["openid"])
423+
client = request.getfixturevalue(client_fixture)
424+
redirect_uri = oidc_client.get_redirect_uris()[0]
425+
resp = client.get(
426+
reverse("idp:oidc:authorization")
427+
+ "?"
428+
+ urlencode(
429+
{
430+
"client_id": oidc_client.id,
431+
"response_type": "code",
432+
"scope": scope,
433+
"redirect_uri": redirect_uri,
434+
"prompt": "none",
435+
}
436+
)
437+
)
438+
assert resp.status_code == HTTPStatus.FOUND
439+
if error:
440+
assert resp["location"] == "https://client/callback?error=" + error
441+
else:
442+
assert resp["location"].startswith("https://client/callback?code=")

allauth/idp/oidc/views.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,16 @@ def get(self, request, *args, **kwargs):
8888
return response
8989
orequest = extract_params(self.request)
9090
try:
91-
self._scopes, self._request_info = (
92-
get_server().validate_authorization_request(*orequest)
91+
server = get_server()
92+
self._scopes, self._request_info = server.validate_authorization_request(
93+
*orequest
9394
)
95+
if "none" in self._request_info.get("prompt", ()):
96+
oresponse = server.create_authorization_response(
97+
*orequest, scopes=self._scopes
98+
)
99+
return convert_response(*oresponse)
100+
94101
# Errors that should be shown to the user on the provider website
95102
except errors.FatalClientError as e:
96103
return respond_html_error(request, e)
@@ -134,6 +141,8 @@ def _login_required(self, request) -> Optional[HttpResponse]:
134141
prompts = prompt.split()
135142
if "login" in prompts:
136143
return self._handle_login_prompt(request, prompts)
144+
if "none" in prompts:
145+
return None
137146
if request.user.is_authenticated:
138147
return None
139148
return login_required()(None)(request) # type:ignore[misc,type-var]

0 commit comments

Comments
 (0)