diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index dcc36e610..d726b0986 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from pydantic.config import BaseConfig from pydantic.fields import ModelField - +from cryptography.fernet import Fernet T = TypeVar("T") @@ -42,6 +42,14 @@ def validate(cls, value: Any) -> SecretStr: return super().validate(value) +class FernetKey(SecretStr): + fernet: Fernet + + def __init__(self, data: str): + super().__init__(data) + self.fernet = Fernet(self.get_secret_value()) + + class LocalFileUrl(AnyUrl): host_required = False allowed_schemes = {"file"} diff --git a/diracx-routers/src/diracx/routers/auth.py b/diracx-routers/src/diracx/routers/auth.py index 18a2d4464..88c5c5872 100644 --- a/diracx-routers/src/diracx/routers/auth.py +++ b/diracx-routers/src/diracx/routers/auth.py @@ -16,6 +16,7 @@ from authlib.jose import JoseError, JsonWebKey, JsonWebToken from authlib.oidc.core import IDToken from cachetools import TTLCache +from cryptography.fernet import Fernet from fastapi import ( Depends, Form, @@ -41,7 +42,7 @@ SecurityProperty, UnevaluatedProperty, ) -from diracx.core.settings import ServiceSettingsBase, TokenSigningKey +from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus from .dependencies import ( @@ -64,6 +65,9 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"): device_flow_expiration_seconds: int = 600 authorization_flow_expiration_seconds: int = 300 + # State key is used to encrypt/decrypt the state dict passed to the IAM + state_key: FernetKey + token_issuer: str = "http://lhcbdirac.cern.ch/" token_audience: str = "dirac" token_key: TokenSigningKey @@ -387,7 +391,7 @@ async def initiate_device_flow( async def initiate_authorization_flow_with_iam( - config, vo: str, redirect_uri: str, state: dict[str, str] + config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet ): # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1 code_verifier = secrets.token_hex() @@ -406,10 +410,9 @@ async def initiate_authorization_flow_with_iam( # Take these two from CS/.well-known authorization_endpoint = server_metadata["authorization_endpoint"] - # TODO : encrypt it for good - encrypted_state = base64.urlsafe_b64encode( - json.dumps(state | {"vo": vo, "code_verifier": code_verifier}).encode() - ).decode() + encrypted_state = encrypt_state( + state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite + ) urlParams = [ "response_type=code", @@ -503,15 +506,28 @@ async def do_device_flow( } authorization_flow_url = await initiate_authorization_flow_with_iam( - config, parsed_scope["vo"], redirect_uri, state_for_iam + config, + parsed_scope["vo"], + redirect_uri, + state_for_iam, + settings.state_key.fernet, ) return RedirectResponse(authorization_flow_url) -def decrypt_state(state): +def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str: + """Encrypt the state dict and return it as a string""" + return cipher_suite.encrypt( + base64.urlsafe_b64encode(json.dumps(state_dict).encode()) + ).decode() + + +def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]: + """Decrypt the state string and return it as a dict""" try: - # TODO: There have been better schemes like rot13 - return json.loads(base64.urlsafe_b64decode(state).decode()) + return json.loads( + base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode() + ) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state" @@ -534,7 +550,7 @@ async def finish_device_flow( can map it to the corresponding device flow using the user_code in the cookie/session """ - decrypted_state = decrypt_state(state) + decrypted_state = decrypt_state(state, settings.state_key.fernet) assert decrypted_state["grant_type"] == GrantType.device_code id_token = await get_token_from_iam( @@ -983,6 +999,7 @@ async def authorization_flow( parsed_scope["vo"], f"{request.url.replace(query='')}/complete", state_for_iam, + settings.state_key.fernet, ) return responses.RedirectResponse(authorization_flow_url) @@ -997,7 +1014,7 @@ async def authorization_flow_complete( config: Config, settings: AuthSettings, ): - decrypted_state = decrypt_state(state) + decrypted_state = decrypt_state(state, settings.state_key.fernet) assert decrypted_state["grant_type"] == GrantType.authorization_code id_token = await get_token_from_iam( diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py index 2f9679ee2..d282c4734 100644 --- a/diracx-routers/tests/auth/test_standard.py +++ b/diracx-routers/tests/auth/test_standard.py @@ -8,16 +8,21 @@ import httpx import jwt import pytest +from cryptography.fernet import Fernet from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa +from fastapi import HTTPException from pytest_httpx import HTTPXMock from diracx.core.config import Config from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty from diracx.routers.auth import ( AuthSettings, + GrantType, _server_metadata_cache, create_token, + decrypt_state, + encrypt_state, get_server_metadata, parse_and_validate_scope, ) @@ -370,6 +375,7 @@ async def test_refresh_token_invalid(test_client, auth_httpx_mock: HTTPXMock): new_auth_settings = AuthSettings( token_key=pem, + state_key=Fernet.generate_key(), allowed_redirects=[ "http://diracx.test.invalid:8000/api/docs/oauth2-redirect", ], @@ -680,3 +686,38 @@ def test_parse_scopes_invalid(vos, groups, scope, expected_error): available_properties = SecurityProperty.available_properties() with pytest.raises(ValueError, match=expected_error): parse_and_validate_scope(scope, config, available_properties) + + +def test_encrypt_decrypt_state_valid_state(fernet_key): + """Test that decrypt_state returns the correct state""" + fernet = Fernet(fernet_key) + # Create a valid state + state_dict = { + "vo": "lhcb", + "code_verifier": secrets.token_hex(), + "user_code": "AE19U", + "grant_type": GrantType.device_code.value, + } + + state = encrypt_state(state_dict, fernet) + result = decrypt_state(state, fernet) + + assert result == state_dict + + # Create an empty state + state_dict = {} + + state = encrypt_state(state_dict, fernet) + result = decrypt_state(state, fernet) + + assert result == state_dict + + +def test_encrypt_decrypt_state_invalid_state(fernet_key): + """Test that decrypt_state raises an error when the state is invalid""" + state = "invalid_state" # Invalid state string + + with pytest.raises(HTTPException) as exc_info: + decrypt_state(state, fernet_key) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid state" diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py index 99a044682..591b8741e 100644 --- a/diracx-testing/src/diracx/testing/__init__.py +++ b/diracx-testing/src/diracx/testing/__init__.py @@ -69,11 +69,19 @@ def rsa_private_key_pem() -> str: @pytest.fixture(scope="session") -def test_auth_settings(rsa_private_key_pem) -> AuthSettings: +def fernet_key() -> str: + from cryptography.fernet import Fernet + + return Fernet.generate_key().decode() + + +@pytest.fixture(scope="session") +def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings: from diracx.routers.auth import AuthSettings yield AuthSettings( token_key=rsa_private_key_pem, + state_key=fernet_key, allowed_redirects=[ "http://diracx.test.invalid:8000/api/docs/oauth2-redirect", ],