Skip to content

Commit 2f25e8b

Browse files
committed
🏗️(project) migrate to pydantic v2 and switch tests to polyfactory
Migrating to `pydantic` v2 should speed up processing and allow interoperability with projects such as `warren`. This migration makes the hypothesis package used in tests obsolete, which is why we introduce `polyfactory`.
1 parent 361f395 commit 2f25e8b

File tree

157 files changed

+2621
-2420
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

157 files changed

+2621
-2420
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ and this project adheres to
88

99
## [Unreleased]
1010

11+
### Changed
12+
13+
- Upgrade `pydantic` to `2.7.0`
14+
- Migrate model tests from hypothesis strategies to polyfactory
15+
1116
## [4.2.0] - 2024-04-08
1217

1318
### Added

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ dependencies = [
3232
# library (mostly models).
3333
"importlib-metadata>=7.0.1, <8.0",
3434
"langcodes>=3.2.0",
35-
"pydantic[dotenv,email]>=1.10.0, <2.0",
35+
"pydantic[email]>=2.5.3,<3.0",
36+
"pydantic_settings>=2.1.0,<3.0",
3637
"rfc3987>=1.3.0",
3738
]
3839
dynamic = ["version"]
@@ -92,7 +93,6 @@ dev = [
9293
"black==24.3.0",
9394
"cryptography==42.0.5",
9495
"factory-boy==3.3.0",
95-
"hypothesis<6.92.0", # pin as hypothesis 6.92.0 observability feature seems broken
9696
"logging-gelf==0.0.32",
9797
"mike==2.0.0",
9898
"mkdocs==1.5.3",
@@ -102,6 +102,7 @@ dev = [
102102
"moto==5.0.5",
103103
"mypy==1.9.0",
104104
"neoteroi-mkdocs==1.0.5",
105+
"polyfactory==2.15.0",
105106
"pyfakefs==5.4.0",
106107
"pymdown-extensions==10.7.1",
107108
"pytest<8.0.0", # pin as pytest-httpx<0.23.0 is not compatible with pytest 8.0.0

src/ralph/api/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,7 @@ async def whoami(
5151
user: AuthenticatedUser = Depends(get_authenticated_user),
5252
) -> Dict[str, Any]:
5353
"""Return the current user's username along with their scopes."""
54-
return {"agent": user.agent, "scopes": user.scopes}
54+
return {
55+
"agent": user.agent.model_dump(mode="json", exclude_none=True),
56+
"scopes": user.scopes,
57+
}

src/ralph/api/auth/basic.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Basic authentication & authorization related tools for the Ralph API."""
22

33
import logging
4+
import os
45
from functools import lru_cache
56
from pathlib import Path
67
from threading import Lock
@@ -10,7 +11,7 @@
1011
from cachetools import TTLCache, cached
1112
from fastapi import Depends, HTTPException, status
1213
from fastapi.security import HTTPBasic, HTTPBasicCredentials
13-
from pydantic import BaseModel, root_validator
14+
from pydantic import RootModel, model_validator
1415
from starlette.authentication import AuthenticationError
1516

1617
from ralph.api.auth.user import AuthenticatedUser
@@ -40,45 +41,42 @@ class UserCredentials(AuthenticatedUser):
4041
username: str
4142

4243

43-
class ServerUsersCredentials(BaseModel):
44+
class ServerUsersCredentials(RootModel[List[UserCredentials]]):
4445
"""Custom root pydantic model.
4546
4647
Describe expected list of all server users credentials as stored in
4748
the credentials file.
4849
4950
Attributes:
50-
__root__ (List): Custom root consisting of the
51+
root (List): Custom root consisting of the
5152
list of all server users credentials.
5253
"""
5354

54-
__root__: List[UserCredentials]
55-
5655
def __add__(self, other) -> Any: # noqa: D105
57-
return ServerUsersCredentials.parse_obj(self.__root__ + other.__root__)
56+
return ServerUsersCredentials.model_validate(self.root + other.root)
5857

5958
def __getitem__(self, item: int) -> UserCredentials: # noqa: D105
60-
return self.__root__[item]
59+
return self.root[item]
6160

6261
def __len__(self) -> int: # noqa: D105
63-
return len(self.__root__)
62+
return len(self.root)
6463

6564
def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105
66-
return iter(self.__root__)
65+
return iter(self.root)
6766

68-
@root_validator
69-
@classmethod
70-
def ensure_unique_username(cls, values: Any) -> Any:
67+
@model_validator(mode="after")
68+
def ensure_unique_username(self) -> Any:
7169
"""Every username should be unique among registered users."""
72-
usernames = [entry.username for entry in values.get("__root__")]
70+
usernames = [entry.username for entry in self.root]
7371
if len(usernames) != len(set(usernames)):
7472
raise ValueError(
7573
"You cannot create multiple credentials with the same username"
7674
)
77-
return values
75+
return self
7876

7977

8078
@lru_cache()
81-
def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
79+
def get_stored_credentials(auth_file: os.PathLike) -> ServerUsersCredentials:
8280
"""Helper to read the credentials/scopes file.
8381
8482
Read credentials from JSON file and store them to avoid reloading them with every
@@ -96,7 +94,9 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
9694
msg = "Credentials file <%s> not found."
9795
logger.warning(msg, auth_file)
9896
raise AuthenticationError(msg.format(auth_file))
99-
return ServerUsersCredentials.parse_file(auth_file)
97+
98+
with open(auth_file, encoding=settings.LOCALE_ENCODING) as f:
99+
return ServerUsersCredentials.model_validate_json(f.read())
100100

101101

102102
@cached(

src/ralph/api/auth/oidc.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from fastapi.security import HTTPBearer, OpenIdConnect
1010
from jose import ExpiredSignatureError, JWTError, jwt
1111
from jose.exceptions import JWTClaimsError
12-
from pydantic import AnyUrl, BaseModel, Extra
12+
from pydantic import AnyUrl, BaseModel, ConfigDict
1313
from typing_extensions import Annotated
1414

1515
from ralph.api.auth.user import AuthenticatedUser, UserScopes
@@ -45,14 +45,13 @@ class IDToken(BaseModel):
4545

4646
iss: str
4747
sub: str
48-
aud: Optional[str]
48+
aud: Optional[str] = None
4949
exp: int
5050
iat: int
51-
scope: Optional[str]
52-
target: Optional[str]
51+
scope: Optional[str] = None
52+
target: Optional[str] = None
5353

54-
class Config: # noqa: D106
55-
extra = Extra.ignore
54+
model_config = ConfigDict(extra="ignore")
5655

5756

5857
@lru_cache()
@@ -144,7 +143,7 @@ def get_oidc_user(
144143
headers={"WWW-Authenticate": "Bearer"},
145144
) from exc
146145

147-
id_token = IDToken.parse_obj(decoded_token)
146+
id_token = IDToken.model_validate(decoded_token)
148147

149148
user = AuthenticatedUser(
150149
agent={"openid": f"{id_token.iss}/{id_token.sub}"},

src/ralph/api/auth/user.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Authenticated user for the Ralph API."""
22

3-
from typing import Dict, FrozenSet, Literal, Optional
3+
from typing import FrozenSet, Literal, Optional
44

5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, RootModel
6+
7+
from ralph.models.xapi.base.agents import BaseXapiAgent
68

79
Scope = Literal[
810
"statements/write",
@@ -18,7 +20,7 @@
1820
]
1921

2022

21-
class UserScopes(FrozenSet[Scope]):
23+
class UserScopes(RootModel[FrozenSet[Scope]]):
2224
"""Scopes available to users."""
2325

2426
def is_authorized(self, requested_scope: Scope):
@@ -47,19 +49,11 @@ def is_authorized(self, requested_scope: Scope):
4749
}
4850

4951
expanded_user_scopes = set()
50-
for scope in self:
52+
for scope in self.root:
5153
expanded_user_scopes.update(expanded_scopes.get(scope, {scope}))
5254

5355
return requested_scope in expanded_user_scopes
5456

55-
@classmethod
56-
def __get_validators__(cls): # noqa: D105
57-
def validate(value: FrozenSet[Scope]):
58-
"""Transform value to an instance of UserScopes."""
59-
return cls(value)
60-
61-
yield validate
62-
6357

6458
class AuthenticatedUser(BaseModel):
6559
"""Pydantic model for user authentication.
@@ -70,6 +64,6 @@ class AuthenticatedUser(BaseModel):
7064
target (str or None): The target index or database to store statements into.
7165
"""
7266

73-
agent: Dict
67+
agent: BaseXapiAgent
7468
scopes: UserScopes
75-
target: Optional[str]
69+
target: Optional[str] = None

src/ralph/api/forwarding.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ async def forward_xapi_statements(
4242
try:
4343
# NB: post or put
4444
req = await getattr(client, method)(
45-
forwarding.url,
45+
str(forwarding.url),
4646
json=statements,
4747
auth=(forwarding.basic_username, forwarding.basic_password),
4848
timeout=forwarding.timeout,
4949
)
5050
req.raise_for_status()
5151
msg = "Forwarded %s statements to %s with success."
5252
if isinstance(statements, list):
53-
logger.debug(msg, len(statements), forwarding.url)
53+
logger.debug(msg, len(statements), str(forwarding.url))
5454
else:
55-
logger.debug(msg, 1, forwarding.url)
55+
logger.debug(msg, 1, str(forwarding.url))
5656
except (RequestError, HTTPStatusError) as error:
5757
logger.error("Failed to forward xAPI statements. %s", error)

src/ralph/api/models.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional, Union
88
from uuid import UUID
99

10-
from pydantic import AnyUrl, BaseModel, Extra
10+
from pydantic import AnyUrl, BaseModel, ConfigDict
1111

1212
from ..models.xapi.base.agents import BaseXapiAgent
1313
from ..models.xapi.base.groups import BaseXapiGroup
@@ -30,13 +30,7 @@ class BaseModelWithLaxConfig(BaseModel):
3030
we receive statements through the API.
3131
"""
3232

33-
class Config:
34-
"""Enable extra properties.
35-
36-
Useful for not having to perform comprehensive validation.
37-
"""
38-
39-
extra = Extra.allow
33+
model_config = ConfigDict(extra="allow", coerce_numbers_to_str=True)
4034

4135

4236
class LaxObjectField(BaseModelWithLaxConfig):
@@ -65,6 +59,6 @@ class LaxStatement(BaseModelWithLaxConfig):
6559
"""
6660

6761
actor: Union[BaseXapiAgent, BaseXapiGroup]
68-
id: Optional[UUID]
62+
id: Optional[UUID] = None
6963
object: LaxObjectField
7064
verb: LaxVerbField

src/ralph/api/routers/health.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def heartbeat(response: Response) -> Heartbeat:
4747
4848
Return a 200 if all checks are successful.
4949
"""
50-
statuses = Heartbeat.construct(
50+
statuses = Heartbeat.model_construct(
5151
database=await await_if_coroutine(BACKEND_CLIENT.status())
5252
)
5353
if not statuses.is_alive:

src/ralph/api/routers/statements.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
status,
2020
)
2121
from fastapi.dependencies.models import Dependant
22-
from pydantic import parse_obj_as
22+
from pydantic import TypeAdapter
2323
from pydantic.types import Json
2424
from typing_extensions import Annotated
2525

@@ -98,14 +98,17 @@ def _enrich_statement_with_authority(
9898
) -> None:
9999
# authority: Information about whom or what has asserted the statement is true.
100100
# https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Data.md#249-authority
101-
statement["authority"] = current_user.agent
101+
statement["authority"] = current_user.agent.model_dump(
102+
exclude_none=True, mode="json"
103+
)
102104

103105

104106
def _parse_agent_parameters(agent_obj: dict) -> AgentParameters:
105107
"""Parse a dict and return an AgentParameters object to use in queries."""
106108
# Transform agent to `dict` as FastAPI cannot parse JSON (seen as string)
107109

108-
agent = parse_obj_as(BaseXapiAgent, agent_obj)
110+
adapter = TypeAdapter(BaseXapiAgent)
111+
agent = adapter.validate_python(agent_obj)
109112

110113
agent_query_params = {}
111114
if isinstance(agent, BaseXapiAgentWithMbox):
@@ -119,7 +122,7 @@ def _parse_agent_parameters(agent_obj: dict) -> AgentParameters:
119122
agent_query_params["account__home_page"] = agent.account.homePage
120123

121124
# Overwrite `agent` field
122-
return AgentParameters.construct(**agent_query_params)
125+
return AgentParameters.model_construct(**agent_query_params)
123126

124127

125128
def strict_query_params(request: Request) -> None:
@@ -141,7 +144,7 @@ def strict_query_params(request: Request) -> None:
141144

142145
@router.get("")
143146
@router.get("/")
144-
async def get( # noqa: PLR0913
147+
async def get( # noqa: PLR0912,PLR0913
145148
request: Request,
146149
current_user: Annotated[
147150
AuthenticatedUser,
@@ -169,7 +172,7 @@ async def get( # noqa: PLR0913
169172
None,
170173
description="Filter, only return Statements matching the specified Verb id",
171174
),
172-
activity: Optional[IRI] = Query(
175+
activity: Optional[str] = Query(
173176
None,
174177
description=(
175178
"Filter, only return Statements for which the Object "
@@ -334,7 +337,14 @@ async def get( # noqa: PLR0913
334337
# Overwrite `agent` field
335338
query_params["agent"] = _parse_agent_parameters(
336339
json.loads(query_params["agent"])
337-
)
340+
).model_dump(mode="json", exclude_none=True)
341+
342+
# Coerce `verb` and `activity` as IRI
343+
if query_params.get("verb"):
344+
query_params["verb"] = IRI(query_params["verb"])
345+
346+
if query_params.get("activity"):
347+
query_params["activity"] = IRI(query_params["activity"])
338348

339349
# mine: If using scopes, only restrict users with limited scopes
340350
if settings.LRS_RESTRICT_BY_SCOPES:
@@ -346,7 +356,9 @@ async def get( # noqa: PLR0913
346356

347357
# Filter by authority if using `mine`
348358
if mine:
349-
query_params["authority"] = _parse_agent_parameters(current_user.agent)
359+
query_params["authority"] = _parse_agent_parameters(
360+
current_user.agent.model_dump(mode="json")
361+
).model_dump(mode="json", exclude_none=True)
350362

351363
if "mine" in query_params:
352364
query_params.pop("mine")
@@ -355,7 +367,7 @@ async def get( # noqa: PLR0913
355367
try:
356368
query_result = await await_if_coroutine(
357369
BACKEND_CLIENT.query_statements(
358-
params=RalphStatementsQuery.construct(
370+
params=RalphStatementsQuery.model_construct(
359371
**{**query_params, "limit": limit}
360372
),
361373
target=current_user.target,
@@ -418,7 +430,7 @@ async def put(
418430
LRS Specification:
419431
https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Communication.md#211-put-statements
420432
"""
421-
statement_as_dict = statement.dict(exclude_unset=True)
433+
statement_as_dict = statement.model_dump(exclude_unset=True, mode="json")
422434
statement_id = str(statement_id)
423435

424436
statement_as_dict.update(id=str(statement_as_dict.get("id", statement_id)))
@@ -516,7 +528,9 @@ async def post( # noqa: PLR0912
516528

517529
# Enrich statements before forwarding
518530
statements_dict = {}
519-
for statement in (x.dict(exclude_unset=True) for x in statements):
531+
for statement in (
532+
x.model_dump(exclude_unset=True, mode="json") for x in statements
533+
):
520534
_enrich_statement_with_id(statement)
521535
# Requests with duplicate statement IDs are considered invalid
522536
if statement["id"] in statements_dict:

0 commit comments

Comments
 (0)