Skip to content

Commit

Permalink
refactor: remove audience claim
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Mar 4, 2024
1 parent 3e75acd commit a1d4c6f
Show file tree
Hide file tree
Showing 12 changed files with 22 additions and 35 deletions.
1 change: 0 additions & 1 deletion diracx-cli/src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ async def login(
async with DiracClient() as api:
data = await api.auth.initiate_device_flow(
client_id=api.client_id,
audience="Dirac server",
scope=" ".join(scopes),
)
print("Now go to:", data.verification_uri_complete)
Expand Down
4 changes: 0 additions & 4 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ async def insert_device_flow(
self,
client_id: str,
scope: str,
audience: str,
) -> tuple[str, str]:
# Because the user_code might be short, there is a risk of conflicts
# This is why we retry multiple times
Expand All @@ -131,7 +130,6 @@ async def insert_device_flow(
stmt = insert(DeviceFlows).values(
client_id=client_id,
scope=scope,
audience=audience,
user_code=user_code,
device_code=hashed_device_code,
)
Expand All @@ -150,7 +148,6 @@ async def insert_authorization_flow(
self,
client_id: str,
scope: str,
audience: str,
code_challenge: str,
code_challenge_method: str,
redirect_uri: str,
Expand All @@ -161,7 +158,6 @@ async def insert_authorization_flow(
uuid=uuid,
client_id=client_id,
scope=scope,
audience=audience,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
redirect_uri=redirect_uri,
Expand Down
2 changes: 0 additions & 2 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class DeviceFlows(Base):
creation_time = DateNowColumn()
client_id = Column(String(255))
scope = Column(String(1024))
audience = Column(String(255))
device_code = Column(String(128), unique=True) # Should be a hash
id_token = NullColumn(JSON())

Expand All @@ -57,7 +56,6 @@ class AuthorizationFlows(Base):
client_id = Column(String(255))
creation_time = DateNowColumn()
scope = Column(String(1024))
audience = Column(String(255))
code_challenge = Column(String(255))
code_challenge_method = Column(String(8))
redirect_uri = Column(String(255))
Expand Down
5 changes: 2 additions & 3 deletions diracx-db/tests/auth/test_authorization_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def test_insert_id_token(auth_db: AuthDB):
# First insert
async with auth_db as auth_db:
uuid = await auth_db.insert_authorization_flow(
"client_id", "scope", "audience", "code_challenge", "S256", "redirect_uri"
"client_id", "scope", "code_challenge", "S256", "redirect_uri"
)

id_token = {"sub": "myIdToken"}
Expand Down Expand Up @@ -68,12 +68,11 @@ async def test_insert(auth_db: AuthDB):
# First insert
async with auth_db as auth_db:
uuid1 = await auth_db.insert_authorization_flow(
"client_id", "scope", "audience", "code_challenge", "S256", "redirect_uri"
"client_id", "scope", "code_challenge", "S256", "redirect_uri"
)
uuid2 = await auth_db.insert_authorization_flow(
"client_id2",
"scope2",
"audience2",
"code_challenge2",
"S256",
"redirect_uri2",
Expand Down
17 changes: 11 additions & 6 deletions diracx-db/tests/auth/test_device_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,22 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch):
# First insert should work
async with auth_db as auth_db:
code, device = await auth_db.insert_device_flow(
"client_id", "scope", "audience"
"client_id",
"scope",
)
assert code == "A" * USER_CODE_LENGTH
assert device

async with auth_db as auth_db:
with pytest.raises(NotImplementedError, match="insert new device flow"):
await auth_db.insert_device_flow("client_id", "scope", "audience")
await auth_db.insert_device_flow("client_id", "scope")

monkeypatch.setattr(secrets, "choice", lambda _: "B")

async with auth_db as auth_db:
code, device = await auth_db.insert_device_flow(
"client_id", "scope", "audience"
"client_id",
"scope",
)
assert code == "B" * USER_CODE_LENGTH
assert device
Expand All @@ -59,10 +61,12 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch):
# First insert
async with auth_db as auth_db:
user_code1, device_code1 = await auth_db.insert_device_flow(
"client_id1", "scope1", "audience1"
"client_id1",
"scope1",
)
user_code2, device_code2 = await auth_db.insert_device_flow(
"client_id2", "scope2", "audience2"
"client_id2",
"scope2",
)

assert user_code1 != user_code2
Expand Down Expand Up @@ -123,7 +127,8 @@ async def test_device_flow_insert_id_token(auth_db: AuthDB):
# First insert
async with auth_db as auth_db:
user_code, device_code = await auth_db.insert_device_flow(
"client_id", "scope", "audience"
"client_id",
"scope",
)

# Make sure it exists, and is Pending
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ async def authorization_flow(
uuid = await auth_db.insert_authorization_flow(
client_id,
scope,
"audience",
code_challenge,
code_challenge_method,
redirect_uri,
Expand Down
5 changes: 1 addition & 4 deletions diracx-routers/src/diracx/routers/auth/device_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class InitiateDeviceFlowResponse(TypedDict):
async def initiate_device_flow(
client_id: str,
scope: str,
audience: str,
request: Request,
auth_db: AuthDB,
config: Config,
Expand Down Expand Up @@ -126,9 +125,7 @@ async def initiate_device_flow(
detail=e.args[0],
) from e

user_code, device_code = await auth_db.insert_device_flow(
client_id, scope, audience
)
user_code, device_code = await auth_db.insert_device_flow(client_id, scope)

verification_uri = str(request.url.replace(query={}))

Expand Down
3 changes: 2 additions & 1 deletion diracx-routers/src/diracx/routers/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,11 @@ async def exchange_token(
}

# Generate access token payload
# For now, the access token is only used to access DIRAC services,
# therefore, the audience is not set and checked
access_payload = {
"sub": sub,
"vo": vo,
"aud": settings.token_audience,
"iss": issuer,
"dirac_properties": parsed_scope["properties"],
"jti": str(uuid4()),
Expand Down
10 changes: 5 additions & 5 deletions diracx-routers/src/diracx/routers/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"):
state_key: FernetKey

token_issuer: str = "http://lhcbdirac.cern.ch/"
token_audience: str = "dirac"
token_key: TokenSigningKey
token_algorithm: str = "RS256"
access_token_expire_minutes: int = 20
Expand Down Expand Up @@ -131,7 +130,8 @@ async def fetch_jwk_set(url: str):
return JsonWebKey.import_key_set(jwk_set)


async def parse_id_token(config, vo, raw_id_token: str, audience: str):
async def parse_id_token(config, vo, raw_id_token: str):
"""Parse and validate the ID token from IAM."""
server_metadata = await get_server_metadata(
config.Registry[vo].IdP.server_metadata_url
)
Expand All @@ -144,7 +144,9 @@ async def parse_id_token(config, vo, raw_id_token: str, audience: str):
claims_cls=IDToken,
claims_options={
"iss": {"values": [server_metadata["issuer"]]},
"aud": {"values": [audience]},
# The audience is a required parameter and is the client ID of the application
# https://openid.net/specs/openid-connect-core-1_0.html#IDToken
"aud": {"values": [config.Registry[vo].IdP.ClientID]},
},
)
token.validate()
Expand Down Expand Up @@ -195,7 +197,6 @@ async def verify_dirac_access_token(
key=settings.token_key.jwk,
claims_options={
"iss": {"values": [settings.token_issuer]},
"aud": {"values": [settings.token_audience]},
},
)
token.validate()
Expand Down Expand Up @@ -394,7 +395,6 @@ async def get_token_from_iam(
config=config,
vo=vo,
raw_id_token=raw_id_token,
audience=config.Registry[vo].IdP.ClientID,
)
except OAuthError:
raise
Expand Down
6 changes: 1 addition & 5 deletions diracx-routers/tests/auth/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def custom_response(request: httpx.Request):
_server_metadata_cache.clear()


async def fake_parse_id_token(raw_id_token: str, audience: str, *args, **kwargs):
async def fake_parse_id_token(raw_id_token: str, *args, **kwargs):
"""Return a fake ID token as if it were returned by an external IdP"""
id_tokens = {
"user1": {
Expand Down Expand Up @@ -183,7 +183,6 @@ async def test_device_flow(test_client, auth_httpx_mock: HTTPXMock):
"/api/auth/device",
params={
"client_id": DIRAC_CLIENT_ID,
"audience": "Dirac server",
"scope": "vo:lhcb group:lhcb_user property:NormalUser",
},
)
Expand Down Expand Up @@ -288,7 +287,6 @@ async def test_flows_with_unallowed_properties(test_client):
"/api/auth/device",
params={
"client_id": DIRAC_CLIENT_ID,
"audience": "Dirac server",
"scope": f"vo:lhcb group:lhcb_user property:{unallowed_property} property:NormalUser",
},
)
Expand Down Expand Up @@ -331,7 +329,6 @@ async def test_flows_with_invalid_properties(test_client):
"/api/auth/device",
params={
"client_id": DIRAC_CLIENT_ID,
"audience": "Dirac server",
"scope": f"vo:lhcb group:lhcb_user property:{invalid_property} property:NormalUser",
},
)
Expand Down Expand Up @@ -661,7 +658,6 @@ def _get_tokens(
"/api/auth/device",
params={
"client_id": DIRAC_CLIENT_ID,
"audience": "Dirac server",
"scope": f"vo:lhcb group:{group} property:{property}",
},
)
Expand Down
2 changes: 0 additions & 2 deletions diracx-testing/src/diracx/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def normal_user(self):
"sub": "testingVO:yellow-sub",
"exp": datetime.now(tz=timezone.utc)
+ timedelta(self.test_auth_settings.access_token_expire_minutes),
"aud": AUDIENCE,
"iss": ISSUER,
"dirac_properties": [NORMAL_USER],
"jti": str(uuid4()),
Expand All @@ -303,7 +302,6 @@ def admin_user(self):
with self.unauthenticated() as client:
payload = {
"sub": "testingVO:yellow-sub",
"aud": AUDIENCE,
"iss": ISSUER,
"dirac_properties": [JOB_ADMINISTRATOR],
"jti": str(uuid4()),
Expand Down
1 change: 0 additions & 1 deletion tests/make-token-local.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def main(token_key):
access_payload = {
"sub": f"{vo}:{sub}",
"vo": vo,
"aud": settings.token_audience,
"iss": settings.token_issuer,
"dirac_properties": dirac_properties,
"jti": str(uuid.uuid4()),
Expand Down

0 comments on commit a1d4c6f

Please sign in to comment.